Segment Tree & Binary Indexed Tree

Use these structures for range queries when the input can be updated between queries (online modify and query); if the input never changes, a prefix sum is good enough.

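  • A segment tree is essentially a full binary tree: every node has either 0 or 2 children

  • Each SegmentTreeNode stores start (left boundary of its interval), end (right boundary of its interval), and val (depends on the problem, e.g. the max value)

  • If the input array has N numbers, the segment tree has N leaf nodes and 2*N - 1 nodes in total

  • Every leaf node is an interval of length 1, such as [0,0] or [1,1]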

  • O(N) time to build tree

  • O(LogN) to query

  • O(LogN) to update
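The structural properties above can be checked with a small sketch. The class names SegNode and SegTreeShape and the helpers countNodes and countLeaves are made up for illustration; the build logic follows the usual split at the midpoint.

```java
// Sketch: build only the shape of a segment tree and count its nodes.
class SegNode {
    final int start, end;   // closed interval [start, end] covered by this node
    SegNode left, right;    // both null for a leaf, both non-null otherwise

    SegNode(int start, int end) {
        this.start = start;
        this.end = end;
    }
}

class SegTreeShape {
    // Each node is created exactly once, so building is O(N).
    static SegNode build(int start, int end) {
        if (start > end) return null;
        SegNode root = new SegNode(start, end);
        if (start == end) return root;          // leaf: interval of length 1
        int mid = start + (end - start) / 2;
        root.left = build(start, mid);
        root.right = build(mid + 1, end);
        return root;
    }

    static int countNodes(SegNode root) {
        if (root == null) return 0;
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    static int countLeaves(SegNode root) {
        if (root == null) return 0;
        if (root.left == null && root.right == null) return 1;
        return countLeaves(root.left) + countLeaves(root.right);
    }

    public static void main(String[] args) {
        int n = 5;
        SegNode root = build(0, n - 1);
        System.out.println(countLeaves(root)); // 5
        System.out.println(countNodes(root));  // 9, i.e. 2*5 - 1
    }
}
```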

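What a segment tree is for and which problems it solves

  • The main object of study is an interval

  • Range sum, range min/max, and other problems defined over intervals

  • Implement three methods by rote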

    • buildTree(int start, int end, int[] A)

    • modify(SegmentTreeNode root, int idx, int val)

    • query(SegmentTreeNode root, int start, int end)

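Problem: given a sequence, either modify the number at some position or query the sum over a range.

If the problem involves only range queries and no modifications, a prefix sum is enough.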

Approach        Update time    Query time    Space
Brute force     O(1)           O(N)          O(1)
Prefix sum      O(N)           O(1)          O(N)
Segment tree    O(logN)        O(logN)       O(N)
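For the query-only case, a prefix-sum array is the simpler tool. The sketch below is one common way to write it; the class name PrefixSum and the use of long for the running totals are choices made here for illustration.

```java
// Sketch: static range sums via a prefix-sum array.
class PrefixSum {
    // prefix[i] holds the sum of nums[0..i-1]; prefix[0] is 0.
    private final long[] prefix;

    PrefixSum(int[] nums) {
        prefix = new long[nums.length + 1];
        for (int i = 0; i < nums.length; i++) {
            prefix[i + 1] = prefix[i] + nums[i];
        }
    }

    // Sum of nums[i..j], inclusive on both ends, in O(1).
    long sumRange(int i, int j) {
        return prefix[j + 1] - prefix[i];
    }

    public static void main(String[] args) {
        PrefixSum ps = new PrefixSum(new int[] {1, 3, 5, 7});
        System.out.println(ps.sumRange(1, 2)); // 8
        System.out.println(ps.sumRange(0, 3)); // 16
    }
}
```

Changing a single element would force every later prefix entry to be recomputed, which is where the O(N) update cost comes from.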

Basic segment tree operations

  • The most basic build: it only constructs the left and right subtrees and does not store a max value

public SegmentTreeNode build(int start, int end) {
    // An empty interval such as [10, 4] has no node.
    if (start > end) return null;

    SegmentTreeNode root = new SegmentTreeNode(start, end);
    // A leaf covers a single index and has no children.
    if (start == end) return root;

    int mid = start + (end - start) / 2;

    root.left = build(start, mid);
    root.right = build(mid + 1, end);
    return root;
}

  • This is the standard build: each node must remember the max value of its interval

public SegmentTreeNode build(int[] A) {
    // An empty or missing array has no tree.
    if (A == null || A.length < 1) return null;
    return buildHelper(A, 0, A.length - 1);
}

private SegmentTreeNode buildHelper(int[] A, int start, int end) {
    if (start > end) return null;
    // A leaf stores the array value at its single index.
    if (start == end) return new SegmentTreeNode(start, end, A[start]);

    int mid = start + (end - start) / 2;
    SegmentTreeNode left = buildHelper(A, start, mid);
    SegmentTreeNode right = buildHelper(A, mid + 1, end);

    // An internal node stores the max of its two children.
    SegmentTreeNode root = new SegmentTreeNode(start, end, Math.max(left.max, right.max));

    root.left = left;
    root.right = right;

    return root;
}

  • A very good DFS exercise

public void modify(SegmentTreeNode root, int index, int value) {
    if (root == null) return;
    // The index is outside this node's interval, so nothing changes here.
    if (index < root.start || index > root.end) return;

    if (index == root.start && index == root.end) {
        // Do not write root.max = Math.max(root.max, value):
        // the leaf value must be overwritten, even if it gets smaller.
        root.max = value;
        return;
    }
    modify(root.left, index, value);
    modify(root.right, index, value);
    // Recompute this node's max from its children on the way back up.
    root.max = Math.max(root.left.max, root.right.max);
}

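  • Query the max value within a range

  • At each level of the query there are only two possibilities

    • If the target interval lies entirely outside root's interval, return immediately (in this problem, return Integer.MIN_VALUE)

    • Otherwise, clamp the effective query interval for the next level to:

      • start = max(query.start, root.start);

      • end = min(query.end, root.end);

    • Then recurse into the child nodes with the clamped interval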

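public int query(SegmentTreeNode root, int start, int end) {
    // An empty tree or a disjoint interval contributes nothing to the max.
    if (root == null) return Integer.MIN_VALUE;
    if (start > root.end || end < root.start) return Integer.MIN_VALUE;

    // Clamp the query to this node's interval.
    start = Math.max(start, root.start);
    end = Math.min(end, root.end);

    // The clamped query covers the whole node: its stored max is the answer.
    if (start == root.start && end == root.end) return root.max;

    int leftMax = query(root.left, start, end);
    int rightMax = query(root.right, start, end);
    return Math.max(leftMax, rightMax);
}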
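To see build, modify, and query working together, here is a self-contained sketch of a max segment tree. MaxNode and MaxSegmentTreeDemo are names invented for this example, and query uses a containment test instead of clamping, which is an equivalent way to detect a fully covered node.

```java
// Sketch: build, point update, and range-max query on one small array.
class MaxNode {
    int start, end, max;
    MaxNode left, right;

    MaxNode(int start, int end, int max) {
        this.start = start;
        this.end = end;
        this.max = max;
    }
}

class MaxSegmentTreeDemo {
    static MaxNode build(int[] a, int start, int end) {
        if (start > end) return null;
        if (start == end) return new MaxNode(start, end, a[start]);
        int mid = start + (end - start) / 2;
        MaxNode root = new MaxNode(start, end, 0);
        root.left = build(a, start, mid);
        root.right = build(a, mid + 1, end);
        root.max = Math.max(root.left.max, root.right.max);
        return root;
    }

    static void modify(MaxNode root, int index, int value) {
        if (root == null || index < root.start || index > root.end) return;
        if (root.start == root.end) {   // leaf for this index: overwrite
            root.max = value;
            return;
        }
        modify(root.left, index, value);
        modify(root.right, index, value);
        root.max = Math.max(root.left.max, root.right.max);
    }

    static int query(MaxNode root, int start, int end) {
        if (root == null || start > root.end || end < root.start) {
            return Integer.MIN_VALUE;   // disjoint: neutral element for max
        }
        if (start <= root.start && root.end <= end) return root.max;
        return Math.max(query(root.left, start, end), query(root.right, start, end));
    }

    public static void main(String[] args) {
        MaxNode root = build(new int[] {1, 4, 2, 3}, 0, 3);
        System.out.println(query(root, 1, 2)); // 4
        modify(root, 1, 0);                    // array is now [1, 0, 2, 3]
        System.out.println(query(root, 0, 3)); // 3
        System.out.println(query(root, 0, 1)); // 1
    }
}
```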

  • Query the count of values within a range

  • Note that root may be null

public int query(SegmentTreeNode root, int start, int end) {
    // An empty tree or a disjoint interval contributes a count of 0.
    if (root == null) return 0;
    if (start > root.end || end < root.start) return 0;

    // Clamp the query to this node's interval.
    start = Math.max(start, root.start);
    end = Math.min(end, root.end);

    if (start == root.start && end == root.end) return root.count;

    int leftCnt = query(root.left, start, end);
    int rightCnt = query(root.right, start, end);

    return leftCnt + rightCnt;
}

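Segment tree applications

Range Sum Query

  • A classic segment tree application

  • Get familiar with the basic operations; the template should be second nature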

class NumArray {
    public class segmentTreeNode {
        int sum;
        segmentTreeNode left;
        segmentTreeNode right;
        int begin;
        int end;
        public segmentTreeNode(int begin, int end) {
            this.begin = begin;
            this.end = end;
            left = right = null;
            sum = 0;
        }
    }

    //build
    segmentTreeNode build(int[] arr, int begin, int end) {
        if (begin > end) {
            return null;
        }
        segmentTreeNode root = new segmentTreeNode(begin, end);

        if (begin == end) {
            root.sum = arr[begin];
            return root;
        }
        int mid = (end - begin) / 2 + begin;
        root.left = build(arr, begin, mid);
        root.right = build(arr, mid + 1, end);
        root.sum = root.left.sum + root.right.sum;
        return root;
    }

    //modify
    public void modify(segmentTreeNode root, int idx, int val) {
        if (root.end == idx && root.begin == idx) {
            root.sum = val;
            return;
        }
        int mid = (root.end - root.begin) / 2 + root.begin;
        if (root.begin <= idx && idx <= mid) {
            modify(root.left, idx, val);
        }
        if (mid < idx && idx <= root.end) {
            modify(root.right, idx, val);
        }
        root.sum = root.left.sum + root.right.sum;

    }

    //query
    public int query(segmentTreeNode root, int begin, int end) {
        if (root.begin > end || root.end < begin) {
            return 0;
        }
        if (begin <= root.begin && end >= root.end) {
            return root.sum;
        }
        return query(root.left, begin, end) + query(root.right, begin, end);
    }

    // Root of the segment tree built over the input array.
    segmentTreeNode node;
    int[] arr;
    public NumArray(int[] nums) {
        arr = nums;
        node = build(arr, 0, nums.length - 1);
    }

    public void update(int i, int val) {
        modify(node, i, val);
    }

    public int sumRange(int i, int j) {
        return query(node, i, j);
    }
}
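The title also names the Binary Indexed Tree (Fenwick tree), which can serve the same update and sumRange interface when the aggregate is a sum. The sketch below is an alternative technique, not the segment tree used above; the class name FenwickNumArray and its internal layout are assumptions made for illustration.

```java
// Sketch: Range Sum Query with a Binary Indexed Tree instead of a segment tree.
class FenwickNumArray {
    private final int[] nums;  // current values, needed to turn "set" into "add delta"
    private final int[] tree;  // 1-indexed; tree[k] covers the lowbit(k) elements ending at k

    FenwickNumArray(int[] input) {
        nums = new int[input.length];
        tree = new int[input.length + 1];
        // Simple O(N logN) construction: insert the elements one by one.
        for (int i = 0; i < input.length; i++) {
            update(i, input[i]);
        }
    }

    // Set position i (0-indexed) to val in O(logN).
    void update(int i, int val) {
        int delta = val - nums[i];
        nums[i] = val;
        for (int k = i + 1; k < tree.length; k += k & -k) {
            tree[k] += delta;
        }
    }

    // Sum of the first count elements.
    private int prefixSum(int count) {
        int sum = 0;
        for (int k = count; k > 0; k -= k & -k) {
            sum += tree[k];
        }
        return sum;
    }

    // Sum of positions i..j, inclusive, in O(logN).
    int sumRange(int i, int j) {
        return prefixSum(j + 1) - prefixSum(i);
    }

    public static void main(String[] args) {
        FenwickNumArray arr = new FenwickNumArray(new int[] {1, 3, 5});
        System.out.println(arr.sumRange(0, 2)); // 9
        arr.update(1, 2);
        System.out.println(arr.sumRange(0, 2)); // 8
    }
}
```

This version is shorter and uses a flat array, but it relies on the aggregate being invertible (a range sum is a difference of two prefix sums), so it does not directly replace the max and min trees in these notes.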

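  • A basic segment tree application

  • My first attempt had a typo in the clamping line of the queryMin method; the correct line is end = Math.min(end, root.end);

  • Took 20 minutes to code; typing speed really needs work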

public class Solution {
    /**
     *@param A, queries: Given an integer array and a query list
     *@return: The result list
     */
    public class segmentTreeNode{
        int start;
        int end;
        int min;
        segmentTreeNode left;
        segmentTreeNode right;
        public segmentTreeNode(int start, int end, int min){
            this.start = start;
            this.end = end;
            this.min = min;
        }
    }

    public segmentTreeNode build(int[] A, int start, int end){
        if(start < 0 || end >= A.length || start > end) return null;
        if(start == end) {
            return new segmentTreeNode(start, end, A[start]);
        }
        int mid = start + (end - start) / 2;
        segmentTreeNode left = build(A, start, mid);
        segmentTreeNode right = build(A, mid + 1, end);
        segmentTreeNode root = new segmentTreeNode(start, end, Math.min(left.min, right.min));
        root.left = left;
        root.right = right;
        return root;
    }

    public int queryMin(segmentTreeNode root, int start, int end){
        if(start > root.end || end < root.start) return Integer.MAX_VALUE;
        start = Math.max(start, root.start);
        end = Math.min(end, root.end);
        if(start == root.start && end == root.end) return root.min;

        int left = queryMin(root.left, start, end);
        int right = queryMin(root.right, start, end);
        return Math.min(left, right);
    }

    public ArrayList<Integer> intervalMinNumber(int[] A, 
                                                ArrayList<Interval> queries) {
        // write your code here
        ArrayList<Integer> res = new ArrayList<>();
        if(A == null || A.length < 1 || queries == null || queries.size() < 1) return res;

        segmentTreeNode root = build(A, 0, A.length - 1);

        for(Interval interval : queries){
            int start = interval.start;
            int end = interval.end;
            int min = queryMin(root, start, end);
            res.add(min);
        }
        return res;
    }
}
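The solution above depends on an Interval class supplied by the judge. The sketch below assumes it has two public int fields, start and end, and repeats a compact version of the min tree so that it runs on its own; IntervalMinDemo and its nested Node class are names invented for this example.

```java
import java.util.ArrayList;
import java.util.List;

// Assumed shape of the judge-provided Interval class.
class Interval {
    int start, end;

    Interval(int start, int end) {
        this.start = start;
        this.end = end;
    }
}

class IntervalMinDemo {
    static class Node {
        int start, end, min;
        Node left, right;

        Node(int start, int end, int min) {
            this.start = start;
            this.end = end;
            this.min = min;
        }
    }

    static Node build(int[] a, int start, int end) {
        if (start > end) return null;
        if (start == end) return new Node(start, end, a[start]);
        int mid = start + (end - start) / 2;
        Node left = build(a, start, mid);
        Node right = build(a, mid + 1, end);
        Node root = new Node(start, end, Math.min(left.min, right.min));
        root.left = left;
        root.right = right;
        return root;
    }

    static int queryMin(Node root, int start, int end) {
        if (root == null || start > root.end || end < root.start) {
            return Integer.MAX_VALUE;   // disjoint: neutral element for min
        }
        if (start <= root.start && root.end <= end) return root.min;
        return Math.min(queryMin(root.left, start, end), queryMin(root.right, start, end));
    }

    static List<Integer> intervalMinNumber(int[] a, List<Interval> queries) {
        List<Integer> res = new ArrayList<>();
        if (a == null || a.length == 0 || queries == null) return res;
        Node root = build(a, 0, a.length - 1);
        for (Interval q : queries) {
            res.add(queryMin(root, q.start, q.end));
        }
        return res;
    }

    public static void main(String[] args) {
        List<Interval> queries = new ArrayList<>();
        queries.add(new Interval(1, 2));
        queries.add(new Interval(0, 4));
        queries.add(new Interval(2, 4));
        System.out.println(intervalMinNumber(new int[] {1, 2, 7, 8, 5}, queries)); // [2, 1, 5]
    }
}
```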
