:heavy_check_mark: 二次元セグメント木 (DataStructure/2d_segment_tree.hpp)

二次元セグメント木

2 次元グリッド(高さ H、幅 W)上の値に対して、1点更新と長方形領域のモノイド積(集約)を行うデータ構造です。

  • set(h, w, x):座標 (h, w) の値を x に更新(0-indexed)
  • get(h, w):座標 (h, w) の値を取得
  • prod(t, l, d, r):長方形 [t, d) × [l, r) のモノイド積を取得(半開区間)

ope により定義されるモノイド(結合的演算 + 単位元)に対して動作します。

コンストラクタ

サイズ指定で初期化

segtree2d(int h, int w)

高さ h、幅 w の 2D 配列を、すべて e() で初期化します。

制約

  • $0 \leq h$
  • $0 \leq w$

計算量

  • $O(HW)$

2 次元配列から構築

segtree2d(const vector<vector<S>>& v)

v(サイズ H × W)から構築します。

制約

  • v.size() == H
  • v[i].size() == W がすべての i で成り立つこと(長方形配列)

計算量

  • $O(HW)$

set

void set(int p_h, int p_w, S x)

座標 (p_h, p_w)(0-indexed)を x に更新します。

制約

  • $0 \leq p_h < H$
  • $0 \leq p_w < W$

計算量

  • $O(\log H \log W)$

get

S get(int p_h, int p_w) const

座標 (p_h, p_w)(0-indexed)の値を返します。

制約

  • $0 \leq p_h < H$
  • $0 \leq p_w < W$

計算量

  • $O(1)$

prod

S prod(int t, int l, int d, int r) const

長方形領域 [t, d) × [l, r)(0-indexed, 半開区間)のモノイド積を返します。
領域が空(t == d または l == r)の場合は e() を返します。

制約

  • $0 \leq t \leq d \leq H$
  • $0 \leq l \leq r \leq W$

計算量

  • $O(\log H \log W)$

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

template <class S, auto op, auto e>
struct segtree2d
{
    static_assert(is_convertible_v<decltype(op), function<S(S, S)>>,
                  "op must work as S(S, S)");
    static_assert(is_convertible_v<decltype(e), function<S()>>,
                  "e must work as S()");

public:
    segtree2d() : segtree2d(0) {}
    explicit segtree2d(int h, int w)
    {
        init(h, w);
    }

    explicit segtree2d(const vector<vector<S>> &v)
    {
        init(v.size(), v[0].size());

        for (int i = 0; i < _h; i++)
        {
            for (int j = 0; j < _w; j++)
            {
                values[id(size_h + i, size_w + j)] = v[id(i, j)];
            }
        }

        for (int w = size_w; w < size_w * 2; w++)
        {
            for (int h = size_h - 1; h >= 1; h--)
            {
                update_h(h, w);
            }
        }

        for (int h = 0; h < size_h * 2; h++)
        {
            for (int w = size_w - 1; w >= 1; w--)
            {
                update_w(h, w);
            }
        }
    }

    void set(int p_h, int p_w, S x)
    {
        assert(0 <= p_h && p_h < _h);
        assert(0 <= p_w && p_w < _w);

        p_h += size_h;
        p_w += size_w;

        values[id(p_h, p_w)] = x;

        for (int i = p_h >> 1; i; i >>= 1)
        {
            update_h(i, p_w);
        }

        for (; p_h; p_h >>= 1)
        {
            for (int j = p_w >> 1; j; j >>= 1)
            {
                update_w(p_h, j);
            }
        }
    }

    S get(int p_h, int p_w) const
    {
        assert(0 <= p_h && p_h < _h);
        assert(0 <= p_w && p_w < _w);

        return values[id(p_h + size_h, p_w + size_w)];
    }

    S prod(int t, int l, int d, int r) const
    {
        assert(0 <= t && t <= d && t <= _h);
        assert(0 <= l && l <= r && r <= _w);

        S smt = e(), smd = e();
        t += size_h;
        d += size_h;
        l += size_w;
        r += size_w;

        for (; t < d; t >>= 1, d >>= 1)
        {
            if (t & 1)
                smt = op(smt, inner_prod(t++, l, r));
            if (d & 1)
                smd = op(inner_prod(--d, l, r), smd);
        }

        return op(smt, smd);
    }

private:
    int _n, size_h, size_w;
    int _h, _w;
    vector<S> values;

    void init(int h, int w)
    {
        _h = h;
        _w = w;
        size_h = (int)bit_ceil((unsigned int)(_h));
        size_w = (int)bit_ceil((unsigned int)(_w));

        values = vector<S>(size_h * size_w * 4, e());
    }

    int id(int h, int w) const
    {
        return h * 2 * size_w + w;
    }

    void update_h(int h, int w)
    {
        values[id(h, w)] = op(values[id(2 * h, w)], values[id(2 * h + 1, w)]);
    }

    void update_w(int h, int w)
    {
        values[id(h, w)] = op(values[id(h, 2 * w)], values[id(h, 2 * w + 1)]);
    }

    S inner_prod(int h, int l, int r) const
    {
        S sml = e(), smr = e();

        for (; l < r; l >>= 1, r >>= 1)
        {
            if (l & 1)
                sml = op(sml, values[id(h, l++)]);
            if (r & 1)
                smr = op(values[id(h, --r)], smr);
        }

        return op(sml, smr);
    }
};
#line 1 "DataStructure/2d_segment_tree.hpp"
#include <bits/stdc++.h>
using namespace std;

template <class S, auto op, auto e>
struct segtree2d
{
    static_assert(is_convertible_v<decltype(op), function<S(S, S)>>,
                  "op must work as S(S, S)");
    static_assert(is_convertible_v<decltype(e), function<S()>>,
                  "e must work as S()");

public:
    segtree2d() : segtree2d(0) {}
    explicit segtree2d(int h, int w)
    {
        init(h, w);
    }

    explicit segtree2d(const vector<vector<S>> &v)
    {
        init(v.size(), v[0].size());

        for (int i = 0; i < _h; i++)
        {
            for (int j = 0; j < _w; j++)
            {
                values[id(size_h + i, size_w + j)] = v[id(i, j)];
            }
        }

        for (int w = size_w; w < size_w * 2; w++)
        {
            for (int h = size_h - 1; h >= 1; h--)
            {
                update_h(h, w);
            }
        }

        for (int h = 0; h < size_h * 2; h++)
        {
            for (int w = size_w - 1; w >= 1; w--)
            {
                update_w(h, w);
            }
        }
    }

    void set(int p_h, int p_w, S x)
    {
        assert(0 <= p_h && p_h < _h);
        assert(0 <= p_w && p_w < _w);

        p_h += size_h;
        p_w += size_w;

        values[id(p_h, p_w)] = x;

        for (int i = p_h >> 1; i; i >>= 1)
        {
            update_h(i, p_w);
        }

        for (; p_h; p_h >>= 1)
        {
            for (int j = p_w >> 1; j; j >>= 1)
            {
                update_w(p_h, j);
            }
        }
    }

    S get(int p_h, int p_w) const
    {
        assert(0 <= p_h && p_h < _h);
        assert(0 <= p_w && p_w < _w);

        return values[id(p_h + size_h, p_w + size_w)];
    }

    S prod(int t, int l, int d, int r) const
    {
        assert(0 <= t && t <= d && t <= _h);
        assert(0 <= l && l <= r && r <= _w);

        S smt = e(), smd = e();
        t += size_h;
        d += size_h;
        l += size_w;
        r += size_w;

        for (; t < d; t >>= 1, d >>= 1)
        {
            if (t & 1)
                smt = op(smt, inner_prod(t++, l, r));
            if (d & 1)
                smd = op(inner_prod(--d, l, r), smd);
        }

        return op(smt, smd);
    }

private:
    int _n, size_h, size_w;
    int _h, _w;
    vector<S> values;

    void init(int h, int w)
    {
        _h = h;
        _w = w;
        size_h = (int)bit_ceil((unsigned int)(_h));
        size_w = (int)bit_ceil((unsigned int)(_w));

        values = vector<S>(size_h * size_w * 4, e());
    }

    int id(int h, int w) const
    {
        return h * 2 * size_w + w;
    }

    void update_h(int h, int w)
    {
        values[id(h, w)] = op(values[id(2 * h, w)], values[id(2 * h + 1, w)]);
    }

    void update_w(int h, int w)
    {
        values[id(h, w)] = op(values[id(h, 2 * w)], values[id(h, 2 * w + 1)]);
    }

    S inner_prod(int h, int l, int r) const
    {
        S sml = e(), smr = e();

        for (; l < r; l >>= 1, r >>= 1)
        {
            if (l & 1)
                sml = op(sml, values[id(h, l++)]);
            if (r & 1)
                smr = op(values[id(h, --r)], smr);
        }

        return op(sml, smr);
    }
};
Back to top page