BiruLyu
6/24/2017 - 5:42 AM

508. Most Frequent Subtree Sum(1st).java

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    /**
     * Returns every subtree sum that occurs with the highest frequency in the tree.
     *
     * <p>A subtree sum is the sum of all node values in the subtree rooted at a node,
     * including that node. When several sums tie for the highest frequency, all of them
     * are returned, in the iteration order of the backing {@link HashMap}.
     *
     * @param root root of the tree; may be {@code null}
     * @return the most frequent subtree sums; an empty array when {@code root} is null
     */
    public int[] findFrequentTreeSum(TreeNode root) {
        Map<Integer, Integer> counts = new HashMap<Integer, Integer>();
        dfs(root, counts);

        // Single pass over the frequency table: a strictly higher count restarts the
        // result list, an equal count extends it.
        int best = 0;
        List<Integer> winners = new ArrayList<Integer>();
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            int count = entry.getValue();
            if (count > best) {
                best = count;
                winners.clear();
                winners.add(entry.getKey());
            } else if (count == best) {
                winners.add(entry.getKey());
            }
        }

        return winners.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Computes the sum of the subtree rooted at {@code root} (post-order) and records
     * every non-empty subtree sum encountered, including this one, in {@code map}.
     *
     * <p>Sums are accumulated in {@code int} and would wrap silently on overflow.
     *
     * @param root subtree root; {@code null} contributes 0 and is not recorded
     * @param map  subtree sum to number of occurrences; updated in place
     * @return the sum of all node values in the subtree
     */
    public int dfs(TreeNode root, Map<Integer, Integer> map) {
        if (root == null) {
            return 0;
        }
        int treeSum = dfs(root.left, map) + root.val + dfs(root.right, map);
        // merge performs lookup-and-increment in one call instead of containsKey + get + put.
        map.merge(treeSum, 1, Integer::sum);
        return treeSum;
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    /** Highest subtree-sum frequency seen so far in the current traversal. */
    int max;

    /**
     * Returns every subtree sum that occurs with the highest frequency in the tree.
     *
     * <p>A subtree sum is the sum of all node values in the subtree rooted at a node,
     * including that node. Ties are all returned, in the order in which each sum first
     * reached the final highest frequency during the post-order traversal.
     *
     * @param root root of the tree; may be {@code null}
     * @return the most frequent subtree sums; an empty array when {@code root} is null
     */
    public int[] findFrequentTreeSum(TreeNode root) {
        // max is an instance field: reset it so a previous call on this same object
        // cannot leave a stale (higher) threshold that suppresses this tree's results.
        max = 0;
        Map<Integer, Integer> map = new HashMap<>();
        List<Integer> list = new ArrayList<>();
        traversal(root, map, list);
        int[] ret = new int[list.size()];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = list.get(i);
        }
        return ret;
    }

    /**
     * Post-order traversal that computes subtree sums and maintains, on the fly, the
     * list of sums sharing the highest frequency seen so far (tracked in {@link #max}).
     *
     * <p>Reads and updates the {@code max} field; a caller other than
     * {@link #findFrequentTreeSum} must reset {@code max} itself before calling.
     *
     * @param root subtree root; {@code null} contributes 0 and is not recorded
     * @param map  subtree sum to number of occurrences; updated in place
     * @param list current most-frequent sums; cleared/extended in place
     * @return the sum of all node values in the subtree
     */
    public int traversal(TreeNode root, Map<Integer, Integer> map, List<Integer> list) {
        if (root == null) {
            return 0;
        }
        int left = traversal(root.left, map, list);
        int right = traversal(root.right, map, list);

        int sum = left + right + root.val;
        // merge returns the updated count (never null here since Integer::sum is non-null).
        int count = map.merge(sum, 1, Integer::sum);
        if (count > max) {
            // New strict leader: previous winners no longer qualify.
            list.clear();
            list.add(sum);
            max = count;
        } else if (count == max) {
            // Counts only grow, so a sum can hit equality with max at most once per level;
            // no duplicates are added.
            list.add(sum);
        }
        return sum;
    }
}