Giới thiệu về cây Chủ tịch (Chairman Tree)

Tôi ghét các cấu trúc dữ liệu khó hiểu! Huhu.

Giới thiệu vấn đề

Cho một dãy số nguyên dương có độ dài \(n\), thực hiện \(q\) truy vấn, mỗi truy vấn yêu cầu tìm giá trị của phần tử nhỏ thứ \(k\) trong đoạn \([l,r]\) của dãy.

Nếu \(n,q \le 10^3\), đây chỉ là bài toán đơn giản cần duyệt toàn bộ; nhưng khi \(n,q \le 10^5\), phương pháp duyệt sẽ không còn hiệu quả, vậy ta phải làm thế nào?

Ta có thể cân nhắc việc lưu trữ các phiên bản lịch sử trong quá trình xây dựng cây, để thuận tiện cho việc tìm kiếm phần tử nhỏ thứ \(k\). Đây chính là nền tảng cốt lõi của cây Chủ tịch.

Cây Chủ tịch là gì?

Cây Chủ tịch, tên đầy đủ là cây đoạn có trọng số dạng bền vững, là một loại cây đoạn bền vững. Đây là một cấu trúc dữ liệu hỗ trợ truy vấn các phiên bản lịch sử. Nếu bạn thực hiện nhiều lần chỉnh sửa trên một mảng, cây Chủ tịch cho phép bạn truy vấn hiệu quả như "giá trị tổng của đoạn sau lần sửa thứ \(k\)" mà không cần lưu toàn bộ dữ liệu cho mỗi lần thay đổi.

Tại sao gọi là cây Chủ tịch?

Cây Chủ tịch được Huang Jiatong độc lập đề xuất và phổ biến vào năm 2010. Vì chữ cái đầu tiên của tên ông (HJT) trùng với tên viết tắt của một vị chủ tịch Trung Quốc thời điểm đó, nên thuật toán này được gọi là "cây Chủ tịch".

Cách hoạt động của cây Chủ tịch?

Nền tảng cốt lõi của cây Chủ tịch đã được đề cập ở phần giới thiệu:

Xem xét việc lưu trữ các phiên bản lịch sử trong quá trình xây dựng cây, để thuận tiện cho việc tìm kiếm phần tử nhỏ thứ \(k\).

Nhưng làm thế nào để lưu trữ chúng?

Một cách đơn giản và mạnh mẽ: tạo riêng từng cây đoạn - nhưng như vậy không gian sẽ bị tràn??

Phân tích kỹ hơn, ta nhận thấy rằng số lượng điểm bị sửa đổi trong mỗi thao tác là giống nhau! Chỉ có \(O(\log n)\) nút bị thay đổi, tạo thành một đường từ gốc đến lá, nghĩa là số lượng nút bị sửa đổi trong mỗi lần cập nhật thực chất bằng chiều cao của cây.

Dưới đây là hình ảnh minh họa từ OI-wiki:

Lưu ý rằng cây Chủ tịch không thể sử dụng phương pháp lưu trữ theo cấp như \(2x\) và \(2x+1\) để biểu diễn con trái và con phải, mà cần mở điểm động và lưu trữ các chỉ số con trái, con phải của mỗi nút.

Vì vậy, chỉ cần lưu trữ thêm gốc của mỗi số được chèn vào là ta đã có thể đạt được tính bền vững!

Tuy nhiên điều này vẫn chưa đủ. Hãy đơn giản hóa vấn đề, thay vì tìm phần tử nhỏ thứ \(k\) trong đoạn \([l,r]\), hãy tìm trong đoạn \([1,r]\). Việc này rất dễ thực hiện, chỉ cần xác định phiên bản gốc tại thời điểm chèn \(r\), rồi dùng cây đoạn có trọng số thông thường.

Làm thế nào để xác định gốc tại thời điểm chèn \(r\)? Đơn giản, đối với mỗi nút \(x\), duy trì gốc \(Root_x\) trong quá trình chèn.

Trở lại với bài toán ban đầu, làm thế nào để giải quyết phần tử nhỏ thứ \(k\) không phải tiền tố? Ở đây ta chỉ cần áp dụng tổng tiền tố - bản chất là tận dụng tính chất trừ đoạn của phép cộng, qua đó đạt được độ phức tạp \(O(1)\) cho mỗi truy vấn. Và trong trường hợp này, thông tin được đếm bởi cây Chủ tịch cũng thỏa mãn tính chất này! Do đó, nếu muốn lấy kết quả từ đoạn \([l,r]\), chỉ cần lấy kết quả từ đoạn \([1,r]\) trừ đi kết quả từ đoạn \([1,l-1]\) là xong > <

Cuối cùng hãy tính toán không gian. Vì mở điểm động, một cây tối đa có \(2n-1\) nút; \(n\) lần sửa đổi, mỗi lần tăng thêm tối đa \(\log_2 n +1\) nút; với \(n \le 10^5\), ước lượng sơ bộ khoảng \(2 \times 10^6\), hoàn toàn ổn!

Triển khai truy vấn tĩnh phần tử nhỏ thứ \(k\)

#include<bits/stdc++.h>
#define ll long long
#define pll pair<ll,ll>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int MAXN = 2e5+10;
const int MAXK = (MAXN<<5);

int n, q, arr[MAXN], sorted[MAXN];
int total_nodes, roots[MAXN], left_child[MAXK], right_child[MAXK], count_val[MAXK];

int read_int() {
    int result = 0, sign = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { result = result * 10 + c - '0'; c = getchar(); }
    return result * sign;
}

void build_tree(int &node, int start, int end) {
    node = (++total_nodes);
    if(start == end) return;
    int mid = (start + end) >> 1;
    build_tree(left_child[node], start, mid);
    build_tree(right_child[node], mid + 1, end);
}

int insert_value(int old_node, int start, int end, int pos) {
    int new_node = (++total_nodes);
    int current = new_node, prev = old_node;
    left_child[current] = left_child[prev];
    right_child[current] = right_child[prev];
    count_val[current] = count_val[prev] + 1;
    
    if(start == end) return new_node;
    
    int mid = (start + end) >> 1;
    if(pos <= mid) {
        left_child[current] = insert_value(left_child[prev], start, mid, pos);
    } else {
        right_child[current] = insert_value(right_child[prev], mid + 1, end, pos);
    }
    return new_node;
}

int find_kth(int left_root, int right_root, int start, int end, int k) {
    if(start == end) return start;
    
    int mid = (start + end) >> 1;
    int left_diff = count_val[left_child[right_root]] - count_val[left_child[left_root]];
    
    if(k <= left_diff) {
        return find_kth(left_child[left_root], left_child[right_root], start, mid, k);
    } else {
        return find_kth(right_child[left_root], right_child[right_root], mid + 1, end, k - left_diff);
    }
}

int main() {
    n = read_int();
    q = read_int();
    
    for(int i = 1; i <= n; i++) {
        arr[i] = read_int();
        sorted[i] = arr[i];
    }
    
    sort(sorted + 1, sorted + n + 1);
    int unique_size = unique(sorted + 1, sorted + n + 1) - sorted - 1;
    
    build_tree(roots[0], 1, unique_size);
    
    for(int i = 1; i <= n; i++) {
        int compressed_pos = lower_bound(sorted + 1, sorted + unique_size + 1, arr[i]) - sorted;
        roots[i] = insert_value(roots[i-1], 1, unique_size, compressed_pos);
    }
    
    while(q--) {
        int l = read_int(), r = read_int(), k = read_int();
        int index = find_kth(roots[l-1], roots[r], 1, unique_size, k);
        cout << sorted[index] << "\n";
    }
    
    return 0;
}

Ứng dụng khác

Truy vấn tĩnh số lượng phần tử nhỏ hơn \(k\) trong đoạn

Vì bản chất tương tự như phần trước [Truy vấn tĩnh phần tử nhỏ thứ \(k\)], nên không mô tả chi tiết, trực tiếp đưa mã nguồn.

#include<bits/stdc++.h>
#define ll long long
#define pll pair<ll,ll>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int MAXN = 1e5+10;
const int MAXK = (MAXN<<5);

struct Query { int left, right, limit; } queries[MAXN];

int test_cases, n, m, q, values[MAXN], compressed[MAXN];
int total_nodes, counter;
int roots[MAXN], left_child[MAXK], right_child[MAXK], count_val[MAXK];
map<int,int> mapping;

int read_int() {
    int result = 0, sign = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { result = result * 10 + c - '0'; c = getchar(); }
    return result * sign;
}

void reset_data() {
    for(int i = 1; i <= total_nodes; i++)
        left_child[i] = right_child[i] = count_val[i] = 0;
    for(int i = 1; i <= n; i++) roots[i] = 0;
    total_nodes = counter = m = 0;
    mapping.clear();
}

void build_tree(int &node, int start, int end) {
    node = (++total_nodes);
    if(start == end) return;
    int mid = (start + end) >> 1;
    build_tree(left_child[node], start, mid);
    build_tree(right_child[node], mid + 1, end);
}

void update_tree(int &new_node, int old_node, int start, int end, int pos) {
    new_node = (++total_nodes);
    left_child[new_node] = left_child[old_node];
    right_child[new_node] = right_child[old_node];
    count_val[new_node] = count_val[old_node] + 1;
    
    if(start == end) return;
    
    int mid = (start + end) >> 1;
    if(pos <= mid) {
        update_tree(left_child[new_node], left_child[old_node], start, mid, pos);
    } else {
        update_tree(right_child[new_node], right_child[old_node], mid + 1, end, pos);
    }
}

int query_count(int left_root, int right_root, int start, int end, int threshold) {
    if(start == end) return count_val[right_root] - count_val[left_root];
    
    int mid = (start + end) >> 1;
    if(mid >= threshold) {
        return query_count(left_child[left_root], left_child[right_root], start, mid, threshold);
    } else {
        return query_count(right_child[left_root], right_child[right_root], mid + 1, end, threshold) + 
               count_val[left_child[right_root]] - count_val[left_child[left_root]];
    }
}

int main() {
    test_cases = read_int();
    for(int tc = 1; tc <= test_cases; tc++) {
        n = read_int();
        q = read_int();
        reset_data();
        
        for(int i = 1; i <= n; i++) {
            compressed[i] = read_int();
            mapping[compressed[i]] = 0;
        }
        
        for(int i = 1; i <= q; i++) {
            queries[i].left = read_int() + 1;
            queries[i].right = read_int() + 1;
            queries[i].limit = read_int();
            mapping[queries[i].limit] = 0;
        }
        
        for(auto &entry : mapping) entry.second = (++counter);
        
        for(int i = 1; i <= n; i++) compressed[i] = mapping[compressed[i]];
        for(int i = 1; i <= q; i++) queries[i].limit = mapping[queries[i].limit];
        
        cout << "Case " << tc << ":\n";
        
        for(int i = 1; i <= n; i++) {
            update_tree(roots[i], roots[i-1], 1, counter, compressed[i]);
        }
        
        for(int i = 1; i <= q; i++) {
            auto [left, right, limit] = queries[i];
            int result = query_count(roots[left-1], roots[right], 1, counter, limit);
            cout << result << "\n";
        }
    }
    return 0;
}

Cập nhật điểm đơn và tìm phần tử nhỏ thứ \(k\)

Đây là dạng quen thuộc "BIT kết hợp với cây Chủ tịch". Vì cần hỗ trợ sửa đổi và là cập nhật điểm đơn, điều này phù hợp với lĩnh vực mạnh mẽ nhất của BIT! Ta chỉ cần kết hợp BIT vào, lưu trữ các gốc của "chuỗi" cần thay đổi vào một mảng tạm, khi sửa đổi đi sâu vào từng nút lá cần thay đổi, điều chỉnh thông tin đúng cách rồi truyền ngược lên là xong > < Tổng thể vẫn rất đơn giản!

#include<bits/stdc++.h>
#define ll long long
#define pll pair<ll,ll>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int MAXN = 2e5+10;
const int MAXK = MAXN*400;

struct Query { int type, left, right, value; } queries[MAXN];

struct SegmentNode {
    int sum, left_child, right_child;
} nodes[MAXK];

int n, m, q, length, arr[MAXN], mapped_values[MAXN];
int total_nodes, roots[MAXN], temp_roots[2][25], temp_counts[2];

int read_int() {
    int result = 0, sign = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { result = result * 10 + c - '0'; c = getchar(); }
    return result * sign;
}

void update_segment(int &node, int start, int end, int pos, int delta) {
    if(!node) node = (++total_nodes);
    nodes[node].sum += delta;
    if(start == end) return;
    
    int mid = (start + end) >> 1;
    if(pos <= mid) {
        update_segment(nodes[node].left_child, start, mid, pos, delta);
    } else {
        update_segment(nodes[node].right_child, mid + 1, end, pos, delta);
    }
}

void bit_update(int idx, int val) {
    int pos = lower_bound(mapped_values + 1, mapped_values + length + 1, arr[idx]) - mapped_values;
    for(; idx <= n; idx += idx & (-idx)) {
        update_segment(roots[idx], 1, length, pos, val);
    }
}

int query_kth_impl(int start, int end, int k) {
    if(start == end) return start;
    
    int mid = (start + end) >> 1, current_count = 0;
    
    for(int i = 1; i <= temp_counts[1]; i++) {
        current_count += nodes[nodes[temp_roots[1][i]].left_child].sum;
    }
    for(int i = 1; i <= temp_counts[0]; i++) {
        current_count -= nodes[nodes[temp_roots[0][i]].left_child].sum;
    }
    
    if(k <= current_count) {
        for(int i = 1; i <= temp_counts[1]; i++) temp_roots[1][i] = nodes[temp_roots[1][i]].left_child;
        for(int i = 1; i <= temp_counts[0]; i++) temp_roots[0][i] = nodes[temp_roots[0][i]].left_child;
        return query_kth_impl(start, mid, k);
    } else {
        for(int i = 1; i <= temp_counts[1]; i++) temp_roots[1][i] = nodes[temp_roots[1][i]].right_child;
        for(int i = 1; i <= temp_counts[0]; i++) temp_roots[0][i] = nodes[temp_roots[0][i]].right_child;
        return query_kth_impl(mid + 1, end, k - current_count);
    }
}

int query_kth(int left, int right, int k) {
    memset(temp_roots, 0, sizeof(temp_roots));
    temp_counts[0] = temp_counts[1] = 0;
    
    for(int i = right; i; i -= i & (-i)) temp_roots[1][++temp_counts[1]] = roots[i];
    for(int i = left - 1; i; i -= i & (-i)) temp_roots[0][++temp_counts[0]] = roots[i];
    
    return query_kth_impl(1, length, k);
}

int main() {
    n = read_int();
    q = read_int();
    
    for(int i = 1; i <= n; i++) {
        arr[i] = read_int();
        mapped_values[++length] = arr[i];
    }
    
    for(int i = 1; i <= q; i++) {
        char op_type;
        cin >> op_type;
        if(op_type == 'Q') {
            queries[i].type = 1;
            queries[i].left = read_int();
            queries[i].right = read_int();
            queries[i].value = read_int();
        } else {
            queries[i].type = 0;
            queries[i].left = read_int();
            queries[i].value = read_int();
            mapped_values[++length] = queries[i].value;
        }
    }
    
    sort(mapped_values + 1, mapped_values + length + 1);
    length = unique(mapped_values + 1, mapped_values + length + 1) - mapped_values - 1;
    
    for(int i = 1; i <= n; i++) bit_update(i, 1);
    
    for(int i = 1; i <= q; i++) {
        auto [type, left, right, value] = queries[i];
        if(type) {
            cout << mapped_values[query_kth(left, right, value)] << "\n";
        } else {
            bit_update(left, -1);
            arr[left] = value;
            bit_update(left, 1);
        }
    }
    
    return 0;
}

Tìm trung vị

Tìm trung vị... thực ra không khó, bạn chỉ cần biết độ dài đoạn rồi áp dụng cách thức của [truy vấn phần tử nhỏ thứ \(k\) trong đoạn]. Nhưng có một bài toán thú vị kết hợp tìm nhị phân, tiền tố và hậu tố \(\max\), tìm trung vị tốt nhất trong đoạn cố định, là một bài toán rất hay.

Đưa mã nguồn.

#include<bits/stdc++.h>
#define ll long long
#define pll pair<ll,ll>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int MAXN = 3e4+5;
const int MAXK = (MAXN<<6);

struct TreeNode {
    int left_child, right_child, sum, left_max, right_max;
} nodes[MAXK];

struct Query { int left, right, value; } queries[MAXN];

int test_cases, n, m, q, pos[MAXN], arr[MAXN], answer;
int num_counter, group[MAXN], total_nodes, roots[MAXN];
map<int,int> mapping;
vector<int> numbers[MAXN];

int read_int() {
    int result = 0, sign = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { result = result * 10 + c - '0'; c = getchar(); }
    return result * sign;
}

void push_up(int node) {
    nodes[node].left_max = max(nodes[nodes[node].left_child].left_max, 
                              nodes[nodes[node].left_child].sum + nodes[nodes[node].right_child].left_max);
    nodes[node].right_max = max(nodes[nodes[node].right_child].right_max, 
                               nodes[nodes[node].right_child].sum + nodes[nodes[node].left_child].right_max);
    nodes[node].sum = nodes[nodes[node].left_child].sum + nodes[nodes[node].right_child].sum;
}

void update_node(int &node, int start, int end, int pos, int value) {
    nodes[++total_nodes] = nodes[node];
    node = total_nodes;
    if(start > pos || end < pos) return;
    if(start == end && pos == start) {
        nodes[node].sum = value;
        if(value > 0) {
            nodes[node].left_max = value;
            nodes[node].right_max = value;
        }
        return;
    }
    int mid = (start + end) >> 1;
    update_node(nodes[node].left_child, start, mid, pos, value);
    update_node(nodes[node].right_child, mid + 1, end, pos, value);
    push_up(node);
}

int query_sum(int node, int start, int end, int left, int right) {
    if(end < left || right < start) return 0;
    if(left <= start && end <= right) return nodes[node].sum;
    int mid = (start + end) >> 1, result = 0;
    result += query_sum(nodes[node].left_child, start, mid, left, right);
    result += query_sum(nodes[node].right_child, mid + 1, end, left, right);
    return result;
}

int query_left_max(int node, int start, int end, int left, int right) {
    if(end < left || right < start) return 0;
    if(left <= start && end <= right) return nodes[node].left_max;
    int mid = (start + end) >> 1, result = 0;
    result = max(result, query_left_max(nodes[node].left_child, start, mid, left, right));
    result = max(result, query_sum(nodes[node].left_child, start, mid, left, right) + 
                 query_left_max(nodes[node].right_child, mid + 1, end, left, right));
    return result;
}

int query_right_max(int node, int start, int end, int left, int right) {
    if(end < left || right < start) return 0;
    if(left <= start && end <= right) return nodes[node].right_max;
    int mid = (start + end) >> 1, result = 0;
    result = max(result, query_right_max(nodes[node].right_child, mid + 1, end, left, right));
    result = max(result, query_sum(nodes[node].right_child, mid + 1, end, left, right) + 
                 query_right_max(nodes[node].left_child, start, mid, left, right));
    return result;
}

int main() {
    n = read_int();
    for(int i = 1; i <= n; i++) {
        arr[i] = read_int();
        mapping[arr[i]] = 0;
    }
    q = read_int();
    
    for(auto &entry : mapping) entry.second = (++num_counter), group[entry.second] = entry.first;
    
    for(int i = 1; i <= n; i++) {
        arr[i] = mapping[arr[i]];
        numbers[arr[i]].pb(i);
    }
    
    for(int i = 1; i <= n; i++) {
        update_node(roots[num_counter + 1], 1, n, i, -1);
    }
    
    for(int i = num_counter; i >= 1; i--) {
        roots[i] = roots[i + 1];
        for(int x : numbers[i]) update_node(roots[i], 1, n, x, 1);
    }
    
    while(q--) {
        int a = read_int(), b = read_int(), c = read_int(), d = read_int();
        int temp_arr[4] = {(a + answer) % n, (b + answer) % n, (c + answer) % n, (d + answer) % n};
        sort(temp_arr, temp_arr + 4);
        a = temp_arr[0] + 1, b = temp_arr[1] + 1, c = temp_arr[2] + 1, d = temp_arr[3] + 1;
        
        int left = 1, right = num_counter, result = 1;
        while(left <= right) {
            int mid = (left + right) >> 1;
            int segment_sum = query_sum(roots[mid], 1, n, b, c);
            int right_max = query_right_max(roots[mid], 1, n, a, b - 1);
            int left_max = query_left_max(roots[mid], 1, n, c + 1, d);
            int ans = segment_sum + left_max + right_max;
            
            if(ans >= 0) {
                result = mid;
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        answer = group[result];
        cout << answer << "\n";
    }
    return 0;
}

Bài tập: [Mẫu] Cây lồng cây

Thực tế không còn là cây Chủ tịch nữa, vì nó không còn là cây đoạn có trọng số, nhưng vẫn là cây đoạn bền vững mở điểm động. Kết hợp với BIT, ý tưởng bài toán rất đơn giản nhưng triển khai phức tạp, vì cần xử lý phần tử nhỏ thứ \(k\), hạng của số \(k\), tiền tố và hậu tố của mỗi số... Logic vẫn đơn giản, cuối cùng đưa mã nguồn!

#include<bits/stdc++.h>
#define ll long long
#define pll pair<ll,ll>
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int MAXN = 5e4+5;
const int MAXM = MAXN*150;

struct Query { int type, left, right, pos, value; } queries[MAXN];

struct TreeNode {
    int sum, left_child, right_child;
} nodes[MAXM];

int n, q, arr[MAXN], num_counter, back_ref[2*MAXN];
int total_nodes, roots[MAXN], temp[2][MAXN];
map<int,int> mapping;

int read_int() {
    int result = 0, sign = 1;
    char c = getchar();
    while(c < '0' || c > '9') { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { result = result * 10 + c - '0'; c = getchar(); }
    return result * sign;
}

void push_up(int node) {
    nodes[node].sum = nodes[nodes[node].left_child].sum + nodes[nodes[node].right_child].sum;
}

void modify_node(int &node, int start, int end, int pos, int delta) {
    if(!node) node = (++total_nodes);
    if(start == end && start == pos) {
        nodes[node].sum += delta;
        return;
    }
    int mid = (start + end) >> 1;
    if(pos <= mid) {
        modify_node(nodes[node].left_child, start, mid, pos, delta);
    } else {
        modify_node(nodes[node].right_child, mid + 1, end, pos, delta);
    }
    push_up(node);
}

void bit_modify(int idx, int val) {
    int x = arr[idx];
    while(idx <= n) {
        modify_node(roots[idx], 1, num_counter, x, val);
        idx += idx & (-idx);
    }
}

int find_number(int left, int right, int k) {
    if(left == right) return left;
    int mid = (left + right) >> 1, current = 0;
    
    for(int i = 1; i <= temp[0][0]; i++) current += nodes[nodes[temp[0][i]].left_child].sum;
    for(int i = 1; i <= temp[1][0]; i++) current -= nodes[nodes[temp[1][i]].left_child].sum;
    
    if(k <= current) {
        for(int i = 1; i <= temp[0][0]; i++) temp[0][i] = nodes[temp[0][i]].left_child;
        for(int i = 1; i <= temp[1][0]; i++) temp[1][i] = nodes[temp[1][i]].left_child;
        return find_number(left, mid, k);
    } else {
        for(int i = 1; i <= temp[0][0]; i++) temp[0][i] = nodes[temp[0][i]].right_child;
        for(int i = 1; i <= temp[1][0]; i++) temp[1][i] = nodes[temp[1][i]].right_child;
        return find_number(mid + 1, right, k - current);
    }
}

int query_number(int left, int right, int k) {
    temp[0][0] = 0;
    temp[1][0] = 0;
    left--;
    
    while(right) {
        temp[0][++temp[0][0]] = roots[right];
        right -= right & (-right);
    }
    while(left) {
        temp[1][++temp[1][0]] = roots[left];
        left -= left & (-left);
    }
    return find_number(1, num_counter, k);
}

int find_rank(int left, int right, int k) {
    if(left == right) return 0;
    int mid = (left + right) >> 1;
    if(k <= mid) {
        for(int i = 1; i <= temp[0][0]; i++) temp[0][i] = nodes[temp[0][i]].left_child;
        for(int i = 1; i <= temp[1][0]; i++) temp[1][i] = nodes[temp[1][i]].left_child;
        return find_rank(left, mid, k);
    } else {
        int current = 0;
        for(int i = 1; i <= temp[0][0]; i++) {
            current += nodes[nodes[temp[0][i]].left_child].sum;
            temp[0][i] = nodes[temp[0][i]].right_child;
        }
        for(int i = 1; i <= temp[1][0]; i++) {
            current -= nodes[nodes[temp[1][i]].left_child].sum;
            temp[1][i] = nodes[temp[1][i]].right_child;
        }
        return current + find_rank(mid + 1, right, k);
    }
}

int query_rank(int left, int right, int k) {
    temp[0][0] = 0;
    temp[1][0] = 0;
    left--;
    
    while(right) {
        temp[0][++temp[0][0]] = roots[right];
        right -= right & (-right);
    }
    while(left) {
        temp[1][++temp[1][0]] = roots[left];
        left -= left & (-left);
    }
    return find_rank(1, num_counter, k) + 1;
}

int find_prev(int left, int right, int k) {
    int rank = query_rank(left, right, k) - 1;
    if(!rank) return 0;
    else return query_number(left, right, rank);
}

int find_next(int left, int right, int k) {
    if(k == num_counter) return num_counter + 1;
    int rank = query_rank(left, right, k + 1);
    if(rank == right - left + 2) return num_counter + 1;
    else return query_number(left, right, rank);
}

int main() {
    n = read_int();
    q = read_int();
    
    for(int i = 1; i <= n; i++) {
        arr[i] = read_int();
        mapping[arr[i]] = 0;
    }
    
    for(int i = 1; i <= q; i++) {
        queries[i].type = read_int();
        if(queries[i].type == 3) {
            queries[i].pos = read_int();
            queries[i].value = read_int();
        } else {
            queries[i].left = read_int();
            queries[i].right = read_int();
            queries[i].value = read_int();
        }
        if(queries[i].type != 2) mapping[queries[i].value] = 0;
    }
    
    for(auto &entry : mapping) {
        entry.second = (++num_counter);
        back_ref[entry.second] = entry.first;
    }
    
    back_ref[0] = -2147483647;
    back_ref[num_counter + 1] = 2147483647;
    
    for(int i = 1; i <= n; i++) {
        arr[i] = mapping[arr[i]];
        bit_modify(i, 1);
    }
    
    for(int i = 1; i <= q; i++) {
        if(queries[i].type != 2) queries[i].value = mapping[queries[i].value];
    }
    
    for(int i = 1; i <= q; i++) {
        if(queries[i].type == 1) {
            cout << query_rank(queries[i].left, queries[i].right, queries[i].value) << "\n";
        } else if(queries[i].type == 2) {
            cout << back_ref[query_number(queries[i].left, queries[i].right, queries[i].value)] << "\n";
        } else if(queries[i].type == 3) {
            bit_modify(queries[i].pos, -1);
            arr[queries[i].pos] = queries[i].value;
            bit_modify(queries[i].pos, 1);
        } else if(queries[i].type == 4) {
            cout << back_ref[find_prev(queries[i].left, queries[i].right, queries[i].value)] << "\n";
        } else {
            cout << back_ref[find_next(queries[i].left, queries[i].right, queries[i].value)] << "\n";
        }
    }
    
    return 0;
}

Ngoài ra, bài này có vẻ như có thể giải bằng phương pháp chia đoạn.

Tóm tắt và kết luận

Cây Chủ tịch, tức cây đoạn có trọng số dạng bền vững, là công cụ thuận tiện để xử lý các truy vấn phần tử nhỏ thứ \(k\) trong đoạn tĩnh, tiền tố và hậu tố. Tư tưởng cốt lõi nằm ở lưu trữ phiên bản lịch sử, tránh tràn không gian thông qua mở điểm động, và tận dụng ý tưởng tổng tiền tố để thực hiện truy vấn đoạn bất kỳ. Từ truy vấn tĩnh đến cập nhật điểm đơn (BIT kết hợp cây Chủ tịch), đến các thao tác phức tạp như trung vị, hạng, cây Chủ tịch thể hiện khả năng mở rộng mạnh mẽ. Dù mã nguồn dài, nhưng có thể giải quyết gọn gàng nhiều bài toán khó, là một công cụ tuyệt vời!

Cảm ơn đã đọc.

Thẻ: Data-Structures persistent-segment-tree competitive-programming algorithm-design tree-data-structure

Đăng vào ngày 19 tháng 5 lúc 23:26