BiruLyu
7/24/2017 - 4:50 AM

327. Count of Range Sum(#).java

public class Solution {
    /**
     * Node of an (unbalanced) binary search tree keyed by prefix-sum value. Besides its own
     * multiplicity ({@code count}), each node records the total multiplicity stored in its left
     * and right subtrees, so rank queries only need to walk a single root-to-leaf path.
     * Declared {@code static} because it never touches the enclosing {@code Solution}.
     */
    static class TreeNode {
        long val;
        int leftSize, rightSize, count;
        TreeNode left, right;

        public TreeNode(long v) {
            this.val = v;
            count = 1;
        }
    }

    /**
     * Counts the subarrays of {@code nums} whose sum lies in {@code [lower, upper]}.
     *
     * <p>With prefix sums {@code S} (where {@code S[0] = 0}), the subarray {@code nums[i..j-1]}
     * qualifies iff {@code S[j] - upper <= S[i] <= S[j] - lower}. Prefix sums are inserted into
     * the tree left to right. Before each insertion, the tree is asked how many earlier prefix
     * sums fall in that window.
     *
     * <p>The tree is not balanced. When prefix sums are monotone (e.g. all-positive input), it
     * degenerates into a chain and the algorithm takes O(n^2) time. Every tree walk is iterative,
     * so such inputs cannot overflow the call stack.
     *
     * @param nums  input values; prefix sums are kept as {@code long} to avoid int overflow
     * @param lower inclusive lower bound on a subarray sum
     * @param upper inclusive upper bound on a subarray sum
     * @return the number of qualifying subarrays
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long[] sums = new long[nums.length + 1];
        for (int i = 0; i < nums.length; i++) {
            sums[i + 1] = sums[i] + nums[i];
        }
        // S[0] = 0 is present from the start so subarrays beginning at index 0 are counted.
        TreeNode root = new TreeNode(sums[0]);
        int output = 0;
        for (int i = 1; i < sums.length; i++) {
            // Query before inserting so only strictly earlier prefix sums are paired with S[i].
            output += rangeSize(root, sums[i] - upper, sums[i] - lower);
            insert(root, sums[i]);
        }
        return output;
    }

    /**
     * Inserts {@code val} into the tree rooted at {@code root}, updating subtree sizes along the
     * way. Equal values share one node and bump its {@code count}. Walks iteratively so a
     * degenerate (chain-shaped) tree cannot cause a {@link StackOverflowError}.
     *
     * @return {@code root}, or a new single-node tree when {@code root} is null
     */
    private TreeNode insert(TreeNode root, long val) {
        if (root == null) {
            return new TreeNode(val);
        }
        TreeNode node = root;
        while (true) {
            if (node.val == val) {
                node.count++;
                return root;
            }
            if (node.val < val) {
                node.rightSize++;
                if (node.right == null) {
                    node.right = new TreeNode(val);
                    return root;
                }
                node = node.right;
            } else {
                node.leftSize++;
                if (node.left == null) {
                    node.left = new TreeNode(val);
                    return root;
                }
                node = node.left;
            }
        }
    }

    /**
     * Returns how many stored values {@code v} (with multiplicity) satisfy
     * {@code lower <= v <= upper}. The result is the tree total minus those below
     * {@code lower} and those above {@code upper}.
     */
    private int rangeSize(TreeNode root, long lower, long upper) {
        int total = root.count + root.leftSize + root.rightSize;
        int smaller = getSmaller(root, lower);
        int larger = getLarger(root, upper);
        return total - smaller - larger;
    }

    /** Returns how many stored values (with multiplicity) are strictly less than {@code lower}. */
    private int getSmaller(TreeNode root, long lower) {
        int smaller = 0;
        TreeNode node = root;
        while (node != null) {
            if (node.val == lower) {
                // Everything in the left subtree is < lower; nothing to the right is.
                return smaller + node.leftSize;
            }
            if (node.val > lower) {
                node = node.left;
            } else {
                // This node and its whole left subtree are < lower; keep looking right.
                smaller += node.leftSize + node.count;
                node = node.right;
            }
        }
        return smaller;
    }

    /** Returns how many stored values (with multiplicity) are strictly greater than {@code upper}. */
    private int getLarger(TreeNode root, long upper) {
        int larger = 0;
        TreeNode node = root;
        while (node != null) {
            if (node.val == upper) {
                // Everything in the right subtree is > upper; nothing to the left is.
                return larger + node.rightSize;
            }
            if (node.val < upper) {
                node = node.right;
            } else {
                // This node and its whole right subtree are > upper; keep looking left.
                larger += node.rightSize + node.count;
                node = node.left;
            }
        }
        return larger;
    }
}
public class Solution {
    /**
     * Prefix sums of the current input. {@code countRangeSum} fills it so that {@code sum[k]}
     * is the total of the first {@code k} input values. {@code mergeSort} then reorders
     * sub-ranges of it in place.
     */
    long[] sum;

    /**
     * Sorts {@code sum[lo..hi]} ascending in place. It returns the number of index pairs
     * {@code lo <= i < j <= hi}, by position before sorting, with
     * {@code lower <= sum[j] - sum[i] <= upper}.
     *
     * <p>Both halves are sorted recursively. For each value of the sorted left half, the matching
     * right-half values form a contiguous window. Its bounds only move forward as the left value
     * grows, so the cross pairs are counted in linear time before the halves are merged.
     */
    int mergeSort(int lo, int hi, int lower, int upper) {
        if (lo >= hi) {
            return 0;
        }
        int mid = (lo + hi) >> 1;
        int pairs = mergeSort(lo, mid, lower, upper);
        pairs += mergeSort(mid + 1, hi, lower, upper);

        // Pass 1: count cross pairs with two monotone window bounds over the right half.
        int windowStart = mid + 1;
        int windowEnd = mid + 1;
        for (int i = lo; i <= mid; i++) {
            long base = sum[i];
            while (windowStart <= hi && sum[windowStart] - base < lower) {
                windowStart++;
            }
            while (windowEnd <= hi && sum[windowEnd] - base <= upper) {
                windowEnd++;
            }
            pairs += windowEnd - windowStart;
        }

        // Pass 2: merge the sorted halves. Right-half values not smaller than the last
        // left-half value are already in their final positions and are not copied.
        long[] merged = new long[hi - lo + 1];
        int written = 0;
        int right = mid + 1;
        for (int i = lo; i <= mid; i++) {
            while (right <= hi && sum[right] < sum[i]) {
                merged[written++] = sum[right++];
            }
            merged[written++] = sum[i];
        }
        System.arraycopy(merged, 0, sum, lo, written);
        return pairs;
    }

    /**
     * Counts the subarrays of {@code nums} whose sum lies in {@code [lower, upper]}, in
     * O(n log n) time via merge sort over the prefix sums.
     *
     * @param nums  input values; prefix sums are accumulated as {@code long}
     * @param lower inclusive lower bound on a subarray sum
     * @param upper inclusive upper bound on a subarray sum
     * @return the number of qualifying subarrays
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        sum = new long[n + 1];
        for (int k = 0; k < n; k++) {
            sum[k + 1] = sum[k] + nums[k];
        }
        return mergeSort(0, n, lower, upper);
    }
}
public class Solution {
    // Inclusive prefix sums of the input: counts[i] == nums[0] + ... + nums[i].
    long[] counts;
    // Inclusive bounds of the accepted sum range, stored for the recursive helper.
    int lower, upper;

    /**
     * Counts the subarrays of {@code nums} whose sum lies in {@code [lower, upper]}.
     *
     * <p>Note: although the helper splits the range in half, it examines every crossing
     * (start, end) pair one by one. Total work is therefore O(n^2), the same as brute force.
     *
     * @param nums  input values
     * @param lower inclusive lower bound on a subarray sum
     * @param upper inclusive upper bound on a subarray sum
     * @return the number of qualifying subarrays
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        this.lower = lower;
        this.upper = upper;
        if (nums.length == 0) {
            return 0;
        }
        counts = new long[nums.length];
        long running = 0;
        for (int i = 0; i < nums.length; i++) {
            running += nums[i];
            counts[i] = running;
        }
        return countWithin(nums, 0, nums.length - 1);
    }

    /**
     * Returns how many subarrays {@code nums[i..j]} with {@code left <= i <= j <= right} have a
     * sum in {@code [lower, upper]}. Divide and conquer: count subarrays that cross the midpoint
     * directly, then recurse into each half.
     */
    private int countWithin(int[] nums, int left, int right) {
        if (left == right) {
            long single = nums[left];
            return (lower <= single && single <= upper) ? 1 : 0;
        }
        int mid = (left + right) / 2;
        int crossing = 0;
        for (int i = left; i <= mid; i++) {
            // Sum of nums[0..i-1]; subtracting it from counts[j] yields the sum of nums[i..j].
            long prefixBefore = counts[i] - nums[i];
            for (int j = mid + 1; j <= right; j++) {
                long rangeSum = counts[j] - prefixBefore;
                if (lower <= rangeSum && rangeSum <= upper) {
                    crossing++;
                }
            }
        }
        return crossing + countWithin(nums, left, mid) + countWithin(nums, mid + 1, right);
    }
}