本次实现的红黑树是在上一篇文章实现的普通二叉树的基础上实现的,普通二叉树的实现已在 https://blog.csdn.net/qq811299838/article/details/104038745 这篇文章中给出,此处不再重复列出。
由于网上已经有很多红黑树的算法介绍,本文将不再介绍算法,只提供代码实现以及测试代码,如有问题,欢迎指出!
关于结点删除:若被删除结点的左子树深度大于右子树深度,则选择左子树中的最大结点作为替换结点;否则,选择右子树中的最小结点作为替换结点。
编译环境:GCC 7.3、vs 2005
红黑树的代码如下:
#ifndef __RBTREE_H__
#define __RBTREE_H__
#if __cplusplus >= 201103L
#include <type_traits> // std::forward、std::move
#endif
#if __cplusplus >= 201103L
#define null nullptr
#else
#define null NULL
#endif
#include "btree.h"
// Default three-way comparator, built only on operator< of _Tp.
//
// Sign convention (note that it is the reverse of strcmp-style comparators):
//   returns  1 when a1 <  a2
//   returns -1 when a2 <  a1
//   returns  0 when neither orders before the other (treated as equal)
//
// The call operator is const so that the comparator can also be invoked
// through a const object or a const reference.
template<typename _Tp>
struct Comparator
{
    int operator()(const _Tp &a1, const _Tp &a2) const
    {
        if(a1 < a2) return 1;
        if(a2 < a1) return -1;
        return 0;
    }
};
template<typename _Tp, typename _Compare = Comparator<_Tp>>
class RBTree
{
public:
typedef _Tp value_type;
typedef _Tp & reference;
typedef _Tp * pointer;
typedef const _Tp & const_reference;
typedef unsigned long size_type;
typedef _Compare compare_type;
#if __cplusplus >= 201103L
typedef _Tp && rvalue_reference;
#endif
private:
enum COLOR
{
RED,
BLACK
};
public:
typedef BinaryTree<value_type> tree_type;
typedef typename tree_type::node_type node_type;
typedef typename node_type::color_type color_type;
private:
// Exchanges the colours of two nodes.
// If exactly one of the two pointers is null, the node that does exist is
// painted black instead; if both are null nothing happens.
void _M_swap_color(node_type *node1, node_type *node2)
{
    const bool has_first = null != node1;
    const bool has_second = null != node2;

    if(has_first && has_second)
    {
        const color_type saved = node1->color();
        node1->_M_color = node2->color();
        node2->_M_color = saved;
        return;
    }

    if(has_second)
    { node2->_M_color = COLOR::BLACK; }
    else if(has_first)
    { node1->_M_color = COLOR::BLACK; }
}
public:
typedef node_type*(*iterator_func)(node_type*);
// Mutable bidirectional iterator over tree nodes.
//
// _Next / _Prev are the traversal functions used to step forward / backward;
// swapping them yields the reverse iterator. A null node represents the
// past-the-end position.
//
// The prefix operators return *this by reference (the canonical form), so
// chained expressions such as ++(++it) advance the iterator twice instead of
// advancing a temporary copy. Dereferencing is const because it does not
// change which node the iterator points at.
template<iterator_func _Next, iterator_func _Prev>
struct iterator_impl
{
    node_type *_M_node;

    iterator_impl(node_type *node = null)
     : _M_node(node) { }

    // Pre-increment: step to the next node and return this iterator.
    iterator_impl& operator++()
    {
        _M_node = _Next(_M_node);
        return *this;
    }

    // Post-increment: step to the next node, return the previous position.
    iterator_impl operator++(int)
    {
        iterator_impl ret(_M_node);
        _M_node = _Next(_M_node);
        return ret;
    }

    // Pre-decrement: step to the previous node and return this iterator.
    iterator_impl& operator--()
    {
        _M_node = _Prev(_M_node);
        return *this;
    }

    // Post-decrement: step to the previous node, return the previous position.
    iterator_impl operator--(int)
    {
        iterator_impl ret(_M_node);
        _M_node = _Prev(_M_node);
        return ret;
    }

    // Element access; the iterator must not be at the end position.
    reference operator*() const
    { return *_M_node->value(); }

    pointer operator->() const
    { return _M_node->value(); }

    // Two iterators are equal when they refer to the same node.
    bool operator==(const iterator_impl &it) const
    { return _M_node == it._M_node; }

    bool operator!=(const iterator_impl &it) const
    { return _M_node != it._M_node; }
};
template<iterator_func _Next, iterator_func _Prev>
struct const_iterator_impl
{
const node_type *_M_node;
const_iterator_impl(const node_type *node = null)
: _M_node(node) { }
const_iterator_impl(const iterator_impl<_Next, _Prev> &it)
: _M_node(it._M_node) { }
const_iterator_impl operator++()
{ return const_iterator_impl(_M_node = _Next(const_cast<node_type*>(_M_node))); }
const_iterator_impl operator++(int)
{
const_iterator_impl ret(_M_node);
_M_node = _Next(const_cast<node_type*>(_M_node));
return ret;
}
const_iterator_impl operator--()
{ return const_iterator_impl(_M_node = _Prev(const_cast<node_type*>(_M_node))); }
const_iterator_impl operator--(int)
{
const_iterator_impl ret(_M_node);
_M_node = _Prev(const_cast<node_type*>(_M_node));
return ret;
}
reference operator*()
{ return *_M_node->value(); }
pointer operator->()
{ return _M_node->value(); }
bool operator==(const const_iterator_impl &it) const
{ return _M_node == it._M_node; }
bool operator!=(const const_iterator_impl &it) const
{ return _M_node != it._M_node; }
};
public:
typedef iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> iterator;
typedef iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> reverse_iterator;
typedef const_iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> const_iterator;
typedef const_iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> const_reverse_iterator;
public:
// Creates an empty tree.
RBTree()
{ }

// Copy constructor: copy-constructs the underlying binary tree from t's.
RBTree(const RBTree &t)
 : _M_tree(t._M_tree)
{ }

#if __cplusplus >= 201103L
// Move constructor: move-constructs the underlying binary tree from t's.
RBTree(RBTree &&t)
 : _M_tree(std::move(t._M_tree))
{ }
#endif
// Size as reported by the underlying binary tree.
size_type size() const
{
    return _M_tree.size();
}

// Depth as reported by the underlying binary tree.
size_type depth() const
{
    return _M_tree.depth();
}

// Read-only access to the underlying binary tree (used e.g. for printing).
const tree_type& get_tree() const
{
    return _M_tree;
}

// True when size() reports zero.
bool empty() const
{
    return 0 == size();
}
// ---- forward iteration -------------------------------------------------
// begin() starts at the node returned by left_child_under(root);
// end() is the default-constructed (null-node) iterator.
iterator begin()
{
    return iterator(_M_tree.left_child_under(_M_tree.root()));
}

iterator end()
{
    return iterator();
}

const_iterator begin() const
{
    return const_iterator(_M_tree.left_child_under(_M_tree.root()));
}

const_iterator end() const
{
    return const_iterator();
}

const_iterator cbegin() const
{
    return begin();
}

const_iterator cend() const
{
    return end();
}

// ---- reverse iteration -------------------------------------------------
// rbegin() starts at the node returned by right_child_under(root);
// rend() is the default-constructed (null-node) reverse iterator.
reverse_iterator rbegin()
{
    return reverse_iterator(_M_tree.right_child_under(_M_tree.root()));
}

reverse_iterator rend()
{
    return reverse_iterator();
}

const_reverse_iterator rbegin() const
{
    return const_reverse_iterator(_M_tree.right_child_under(_M_tree.root()));
}

const_reverse_iterator rend() const
{
    return const_reverse_iterator();
}

const_reverse_iterator crbegin() const
{
    return rbegin();
}

const_reverse_iterator crend() const
{
    return rend();
}
// Removes the element that `it` refers to; an end() iterator is ignored.
//
// NOTE(review): get_erase_reserve(), node_type::swap() and erase_leaf() are
// declared in btree.h, which is not visible here. According to the article
// text at the top of this file, the "reserve" (replacement) node is the
// largest node of the left subtree when the left subtree is deeper than the
// right one, otherwise the smallest node of the right subtree - confirm
// against btree.h.
void erase(const_iterator it)
{
if(it == end())
{ return; }
// The iterator stores a const node pointer; constness is dropped so the tree
// can be modified.
node_type *node = const_cast<node_type*>(it._M_node);
node_type *reserve = _M_tree.get_erase_reserve(node);
// NOTE(review): presumably exchanges the contents of the two nodes so that the
// node physically unlinked below is `reserve` - confirm what swap() exchanges.
node->swap(reserve);
// A red node needs no fix-up; any other colour is rebalanced before the node
// is unlinked.
if(reserve->color() != COLOR::RED)
{
_M_erase_adjust(reserve);
_M_swap_color(node, reserve);
}
_M_tree.erase_leaf(reserve);
}
// Inserts a copy of `v` (delegates to _M_insert) and returns an iterator to
// the element holding that value.
iterator insert(const_reference v)
{
    return _M_insert(v);
}

#if __cplusplus >= 201103L
// Rvalue overload: forwards `v` to the rvalue overload of _M_insert.
iterator insert(rvalue_reference v)
{
    return _M_insert(std::move(v));
}
#endif
// Looks up the element that compares equal to `v` under compare_type.
// Returns an iterator to it, or end() when no such element exists.
iterator find(const_reference v)
{ return _M_find<value_type, compare_type>(v); }

// Const lookup with a caller-supplied comparator type `_CompareType`
// (must be given explicitly, e.g. tree.find<MyCompare>(v)).
// Returns a const_iterator to the match, or end() when there is none.
template<typename _CompareType>
const_iterator find(const_reference v) const
{
    // _M_find is a non-const member, so calling it directly from this const
    // member does not compile once the template is instantiated. _M_find only
    // walks the tree without modifying it, so casting away constness of
    // `this` for the call is safe.
    return const_cast<RBTree*>(this)->template _M_find<value_type, _CompareType>(v);
}
private:
// Inserts `v`; if an element comparing equal already exists, that element is
// overwritten with `v` instead. Returns an iterator to the stored element.
iterator _M_insert(const_reference v)
{
    iterator found;
    if(_M_find_and_insert<value_type, compare_type>(v, found))
    { *found = v; }
    return found;
}
#if __cplusplus >= 201103L
// Rvalue overload of the above.
iterator _M_insert(rvalue_reference v)
{
    iterator found;
    if(_M_find_and_insert<value_type, compare_type>(std::move(v), found))
    {
        // An existing element was found, which means _M_find_and_insert
        // returned before storing `v` anywhere; move it into place rather
        // than copying it.
        *found = std::move(v);
    }
    return found;
}
#endif
// Binary-search descent from the root.
// _CompareType()(input, element) == 0 means "found"; a positive result
// continues in the left subtree, a negative one in the right subtree.
// Returns an iterator to the match, or a default (end) iterator.
template<typename _InputType, typename _CompareType>
iterator _M_find(const _InputType &input)
{
    for(node_type *cur = _M_tree.root(); null != cur; )
    {
        const int order = _CompareType()(input, *cur->value());
        if(0 == order)
        { return iterator(cur); }
        cur = order > 0 ? cur->left_child() : cur->right_child();
    }
    return iterator();
}
// Rebalances the tree before a black node is removed.
//
// `node` carries the "extra black" created by the removal. The loop runs while
// `node` exists and is not red, and handles the four classic cases, expressed
// relative to node's brother (sibling) and the brother's children:
//   - far nephew : the brother's child on the side away from `node`
//   - near nephew: the brother's child on the side next to `node`
//
// The relatives are recomputed from scratch on every pass and the loop stops
// as soon as `node` is the root. Previously they were declared outside the
// loop and kept their values from the previous pass; when the extra black
// reached the root, the stale brother (just painted red) was taken for a red
// sibling and a null parent was dereferenced.
void _M_erase_adjust(node_type *node)
{
    while(null != node && node->color() != COLOR::RED)
    {
        node_type *parent = node->parent();
        if(null == parent)
        {
            // The extra black has reached the root: nothing left to fix.
            break;
        }

        // Work out the family relations for this pass.
        const bool node_is_left = parent->left_child() == node;
        node_type *brother = node_is_left ? parent->right_child() : parent->left_child();
        node_type *far_nephew = null;
        node_type *near_nephew = null;
        if(null != brother)
        {
            far_nephew = node_is_left ? brother->right_child() : brother->left_child();
            near_nephew = node_is_left ? brother->left_child() : brother->right_child();
        }

        if(null != brother && brother->color() == COLOR::RED)
        {
            // Red brother: recolour and rotate so that the next pass sees a
            // black brother.
            brother->_M_color = COLOR::BLACK;
            parent->_M_color = COLOR::RED;
            if(node_is_left)
            { _M_tree.left_rotate(parent); }
            else
            { _M_tree.right_rotate(parent); }
        }
        else if(null != far_nephew && far_nephew->color() == COLOR::RED)
        {
            // Red far nephew: one rotation around the parent finishes the job.
            brother->_M_color = parent->color();
            parent->_M_color = COLOR::BLACK;
            far_nephew->_M_color = COLOR::BLACK;
            if(node_is_left)
            { _M_tree.left_rotate(parent); }
            else
            { _M_tree.right_rotate(parent); }
            break;
        }
        else if(null != near_nephew && near_nephew->color() == COLOR::RED)
        {
            // Red near nephew: rotate the brother so that the next pass hits
            // the far-nephew case.
            brother->_M_color = COLOR::RED;
            near_nephew->_M_color = COLOR::BLACK;
            if(node_is_left)
            { _M_tree.right_rotate(brother); }
            else
            { _M_tree.left_rotate(brother); }
        }
        else
        {
            // Brother has no red child: paint it red and push the extra
            // black up to the parent.
            if(null != brother)
            { brother->_M_color = COLOR::RED; }
            if(parent->color() == COLOR::RED)
            {
                // A red parent absorbs the extra black.
                parent->_M_color = COLOR::BLACK;
                break;
            }
            node = parent;
        }
    }
}
// Restores the red-black colouring after `node` has been linked into the tree.
//
// Each pass looks at node's parent, grandparent and uncle (the grandparent's
// other child). Nothing needs fixing unless the parent is red. The root is
// painted black once the loop ends, so the tree must not be empty when this
// is called.
//
// NOTE(review): when there is no grandparent, the "parent is a left child"
// tests below evaluate to false and the rotate helpers may be called with a
// null grandparent; this relies on BinaryTree::left_rotate/right_rotate
// (btree.h, not visible here) tolerating null - confirm.
void _M_insert_adjust(node_type *node)
{
while(null != node)
{
node_type *parent = node->parent();
node_type *grand_parent = null == parent ? null : parent->parent();
node_type *uncle = null;
// The uncle is the grandparent's child on the opposite side from the parent.
if((null == grand_parent ? null : grand_parent->left_child()) == parent)
{ uncle = null == grand_parent ? null : grand_parent->right_child(); }
else
{ uncle = null == grand_parent ? null : grand_parent->left_child(); }
// Parent is red
if(null != parent && parent->color() == COLOR::RED)
{
// Uncle is red as well: recolour and continue the fix-up from the grandparent
if(null != uncle && uncle->color() == COLOR::RED)
{
parent->_M_color = COLOR::BLACK;
uncle->_M_color = COLOR::BLACK;
grand_parent->_M_color = COLOR::RED;
node = grand_parent;
continue;
}
// The new node is a left child
if((null == parent ? null : parent->left_child()) == node)
{
// Parent is a left child: swap parent/grandparent colours, rotate right, done
if((null == grand_parent ? null : grand_parent->left_child()) == parent)
{
_M_swap_color(parent, grand_parent);
_M_tree.right_rotate(grand_parent);
break;
}
else
{
// Parent is not a left child: rotate the parent right and continue the
// fix-up from the old parent
_M_tree.right_rotate(parent);
node = parent;
continue;
}
}
// The new node is a right child
else
{
// Parent is a left child: rotate the parent left and continue the fix-up
// from the old parent
if((null == grand_parent ? null : grand_parent->left_child()) == parent)
{
_M_tree.left_rotate(parent);
node = parent;
continue;
}
else
{
// Parent is not a left child: swap parent/grandparent colours and rotate left.
// NOTE(review): the mirrored case above uses `break`; here the loop runs one
// more pass, which presumably ends at the final `break` because the parent
// is black after the colour swap - confirm.
_M_swap_color(parent, grand_parent);
_M_tree.left_rotate(grand_parent);
continue;
}
}
}
break;
}
// The root is always painted black.
_M_tree.root()->_M_color = COLOR::BLACK;
}
/* 插入结点,如果结点不存在则插入新结点
* @input 插入的值
* @result 结点的迭代器
* @return 如果结点本来已存在,则返回true
*/
template<typename _InputType, typename _CompareType>
bool _M_find_and_insert(const _InputType &input, iterator &result)
{
if(empty())
{
result._M_node = _M_tree.append_root(input);
return false;
}
node_type *node = _M_tree.root();
while(true)
{
int res = _CompareType()(input, *node->value());
if(res == 0)
{
result._M_node = node;
return true;
}
if(res > 0)
{
if(null == node->left_child())
{
node = _M_tree.append_left(node, input);
break;
}
node = node->left_child();
}
else
{
if(null == node->right_child())
{
node = _M_tree.append_right(node, input);
break;
}
node = node->right_child();
}
}
_M_insert_adjust(node);
result._M_node = node;
return false;
}
#if __cplusplus >= 201103L
template<typename _InputType, typename _CompareType>
bool _M_find_and_insert(_InputType &&input, iterator &result)
{
if(empty())
{
result._M_node = _M_tree.append_root(input);
return false;
}
node_type *node = _M_tree.root();
while(true)
{
int res = _CompareType()(std::forward<value_type>(input), *node->value());
if(res == 0)
{
result._M_node = node;
return true;
}
if(res > 0)
{
if(null == node->left_child())
{
node = _M_tree.append_left(node, std::move(input));
break;
}
node = node->left_child();
}
else
{
if(null == node->right_child())
{
node = _M_tree.append_right(node, std::move(input));
break;
}
node = node->right_child();
}
}
_M_insert_adjust(node);
result._M_node = node;
return false;
}
#endif
private:
tree_type _M_tree;
};
#endif // __RBTREE_H__
测试代码:
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <list>
#include <string>
#include "rbtree.h"
#ifdef _WIN32
#include <windows.h>
#endif
#if __cplusplus < 201103L
#include <sstream>
#endif
// Width, in characters, of one printed tree cell.
#define MAX_NUMBER_BIT 5

// Formats `v` as a cell of exactly MAX_NUMBER_BIT characters, centred with
// spaces; when the padding cannot be split evenly the extra space goes on the
// right. Text that is already MAX_NUMBER_BIT characters or longer is returned
// unpadded.
//
// The padding is computed only after checking that the text fits: the
// unsigned subtraction MAX_NUMBER_BIT - tmp.size() would otherwise wrap
// around for longer numbers and request an enormous amount of padding.
static std::string get_string(int v)
{
#if __cplusplus < 201103L
    std::stringstream ss;
    ss << v;
    const std::string tmp = ss.str();
#else
    const std::string tmp = std::to_string(v);
#endif
    const std::size_t width = MAX_NUMBER_BIT;
    if(tmp.size() >= width)
    { return tmp; }

    const std::size_t total_pad = width - tmp.size();
    const std::size_t left_pad = total_pad / 2;

    std::string result(left_pad, ' ');
    result += tmp;
    result.append(total_pad - left_pad, ' ');
    return result;
}
// Switches the console text colour to red (used for red nodes).
// Windows: console text attribute; Linux: ANSI escape sequence; no-op elsewhere.
static void set_red()
{
#if defined(_WIN32)
    HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
    SetConsoleTextAttribute(console, FOREGROUND_RED);
#elif defined(__linux__)
    std::cout << "\033[31m";
#endif
}
// Switches the console text colour used for "black" nodes. The colour actually
// selected is blue (FOREGROUND_BLUE / ANSI 34).
// Windows: console text attribute; Linux: ANSI escape sequence; no-op elsewhere.
static void set_black()
{
#if defined(_WIN32)
    HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
    SetConsoleTextAttribute(console, FOREGROUND_BLUE);
#elif defined(__linux__)
    std::cout << "\033[34m";
#endif
}
// Switches the console text colour to the one used for ordinary output, which
// is green (FOREGROUND_GREEN / ANSI 32).
// Windows: console text attribute; Linux: ANSI escape sequence; no-op elsewhere.
static void set_default()
{
#if defined(_WIN32)
    HANDLE console = GetStdHandle(STD_OUTPUT_HANDLE);
    SetConsoleTextAttribute(console, FOREGROUND_GREEN);
#elif defined(__linux__)
    std::cout << "\033[32m";
#endif
}
// Element type used by the test: an int wrapper ordered by its value.
struct T
{
    // Implicitly constructible from int; defaults to 0.
    T(int v = 0) : value(v) { }

    // Writes the value to std::cout as a cell formatted by get_string().
    // const: printing does not modify the element, so it can also be called
    // on const objects.
    void print() const
    { std::cout << get_string(value); }

    // Ordering used by the default Comparator.
    bool operator<(const T& t) const
    { return value < t.value; }

    int value;
};
typedef RBTree<T> TestTree;
typedef TestTree::node_type NodeType;
static void print_tree(const BinaryTree<T> &tree)
{
std::list<const NodeType*> s;
s.push_back(tree.root());
bool break_flag = false;
while(!break_flag)
{
break_flag = true;
std::size_t count = s.size();
int print_count = 0;
while(count-- > 0)
{
const NodeType *t = s.front();
s.pop_front();
if(null == t)
{
set_black();
T().print();
set_default();
s.push_back(null);
s.push_back(null);
}
else
{
if(t->color() == 0)
{ set_red(); }
else
{ set_black(); }
t->value()->print();
set_default();
s.push_back(t->left_child());
s.push_back(t->right_child());
if(break_flag)
{ break_flag = null == t->left_child() && null == t->right_child(); }
}
if(++print_count % 2 == 0)
{ std::cout << "|"; }
}
std::cout << std::endl;
}
}
// Test scenario: inserts a fixed sequence of values (printing the tree after
// each insertion), reports size and depth, erases 32 and 35, and finally
// walks the tree with a const_iterator.
void main_func()
{
    TestTree avl;

    // Values inserted before 32.
    static const int leading[] = { 10, 30, 1, 31 };
    for(std::size_t i = 0; i < sizeof(leading) / sizeof(leading[0]); ++i)
    {
        std::cout << "insert: ---->" << avl.insert(T(leading[i]))->value << std::endl;
        print_tree(avl.get_tree());
    }

    // Keep the iterator to 32: it is used for the first erase below.
    TestTree::iterator it1 = avl.insert(T(32));
    std::cout << "insert: ---->" << it1->value << std::endl;
    print_tree(avl.get_tree());

    // Values inserted after 32, each followed by a tree dump.
    static const int trailing[] = { 33, 34, 35, 36, 37 };
    for(std::size_t i = 0; i < sizeof(trailing) / sizeof(trailing[0]); ++i)
    {
        std::cout << "insert: ---->" << avl.insert(T(trailing[i]))->value << std::endl;
        print_tree(avl.get_tree());
    }

    // The last insertion reports size and depth before the tree dump.
    std::cout << "insert: ---->" << avl.insert(T(38))->value << std::endl;
    std::cout << "size: " << avl.size() << std::endl;
    std::cout << "depth: " << avl.depth() << std::endl;
    print_tree(avl.get_tree());

    std::cout << "--------------erase node 32----------------" << std::endl;
    avl.erase(it1);
    print_tree(avl.get_tree());

    std::cout << "--------------erase node 35----------------" << std::endl;
    avl.erase(avl.find(T(35)));
    print_tree(avl.get_tree());

    std::cout << std::endl << "--------------iterator visit-----------" << std::endl;
    TestTree::const_iterator it = avl.begin();
    while(it != avl.end())
    {
        std::cout << it->value << ' ';
        ++it;
    }
    std::cout << std::endl;
}
// Entry point: selects the default console colour and runs the test scenario.
int main()
{
    set_default();
    main_func();
#ifdef _WIN32
    // Keeps the console window open when started from Explorer / the IDE.
    // "pause" is a cmd.exe built-in, so the call is limited to Windows; on
    // other systems it would only produce a "command not found" error.
    std::system("pause");
#endif
    return 0;
}
测试结果: