/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
int rob(TreeNode* root) {
if (!root)
{
return 0;
}
int maxExcludeRoot = 0;
return dfs(root, maxExcludeRoot);
}
private:
int dfs(TreeNode* root, int& maxExcludeRoot) {
if (!root)
{
maxExcludeRoot = 0;
return 0;
}
int maxExcludeRootRight = 0;
int maxExcludeRootLeft = 0;
//Caculate the max value of right, left subtree
int maxLeft = dfs(root->left, maxExcludeRootLeft);
int maxRight = dfs(root->right, maxExcludeRootRight);
//maxIncludeRoot strored value that current value + max value in last last layer
int maxIncludeRoot = root->val + maxExcludeRootLeft + maxExcludeRootRight;
//maxExcludeRoot strored max value in last layer
maxExcludeRoot = maxLeft + maxRight;
return max(maxIncludeRoot, maxExcludeRoot);
}
};