本文不介绍归并排序算法本身的原理,不熟悉的读者可以先参考百度百科中的“归并排序”词条。
编译环境:支持 C++11 标准的编译器(例如 g++ -std=c++11)。
代码实现如下:
#include <cstddef>      // std::size_t
#include <memory>       // std::unique_ptr
#include <type_traits>  // std::remove_reference, std::remove_cv, std::enable_if, std::is_same
#include <utility>      // std::declval
// Unsigned type used for element counts and distances throughout this file.
// std::size_t instead of unsigned long: on LLP64 targets (e.g. 64-bit Windows)
// unsigned long is only 32 bits wide, so counts of 2^32 or more would be
// silently truncated (leading to undersized buffers in msort).
using size_type = std::size_t;
// Absolute difference |a - b| of two values of the same type, as size_type.
// Requires only operator< and operator- (e.g. integers or random-access
// iterators). The larger operand is always on the left of the subtraction,
// so unsigned operands never wrap around.
// NOTE: when both arguments are standard-library iterators, an unqualified
// call distance(x, y) can also find std::distance through argument-dependent
// lookup and be ambiguous; write (distance)(x, y) to suppress ADL.
template<typename T>
constexpr size_type distance(const T &a, const T &b)
{ return (b < a) ? (a - b) : (b - a); }
// Three-way comparator built solely on operator<.
//   returns  1 when lhs < rhs  (lhs should come first),
//   returns -1 when rhs < lhs  (rhs should come first),
//   returns  0 otherwise       (neither is less: treated as equivalent).
// Note the sign convention is the opposite of strcmp/memcmp.
// rhs < lhs is only evaluated when lhs < rhs is false.
template<typename T1, typename T2>
struct Comparator2
{
    int operator()(const T1 &lhs, const T2 &rhs) const
    {
        return (lhs < rhs) ? 1
             : (rhs < lhs) ? -1
             : 0;
    }
};
/* Merge two sorted ranges into one sorted range.
 * All ranges are half-open: [beg, end).
 * @dst   destination of the merged sequence. Only `*dst++ = value` is used,
 *        so any output iterator works. It must have room for
 *        (src1_end - src1_beg) + (src2_end - src2_beg) elements and should
 *        not overlap either source range.
 * [@src1_beg, @src1_end)  first sorted input range
 * [@src2_beg, @src2_end)  second sorted input range
 * @c     three-way comparator: c(a, b) > 0 means a sorts before b (the
 *        convention of Comparator2). Defaults to Comparator2 over the two
 *        ranges' element types.
 *
 * The merge is stable: when two elements are equivalent, the one from the
 * first range is written first.
 *
 * NOTE: if dst and all source iterators share one standard-library iterator
 * type, an unqualified call merge(...) can also find std::merge through
 * argument-dependent lookup and be ambiguous; write (merge)(...) instead.
 */
template<typename _ForwardIter1,
         typename _ForwardIter2 = _ForwardIter1,
         typename _OutputForwardIter = _ForwardIter1,
         // Strip the reference from `*it` so Comparator2's `const T &`
         // parameters bind to any element expression. Keeping it broke
         // iterators whose operator* yields T&& (e.g. std::move_iterator):
         // `const (T&&) &` collapses to `T &`, which cannot bind an xvalue.
         typename _Compare = Comparator2<
             typename std::remove_reference<
                 decltype(*std::declval<_ForwardIter1>())>::type,
             typename std::remove_reference<
                 decltype(*std::declval<_ForwardIter2>())>::type> >
void merge(_OutputForwardIter dst,
           _ForwardIter1 src1_beg, _ForwardIter1 src1_end,
           _ForwardIter2 src2_beg, _ForwardIter2 src2_end,
           _Compare c = _Compare())
{
    while(src1_beg != src1_end && src2_beg != src2_end)
    {
        // Take from the second range only when it is strictly smaller, so
        // equivalent elements keep their original relative order.
        if(c(*src2_beg, *src1_beg) > 0)
        { *dst++ = *src2_beg++; }
        else
        { *dst++ = *src1_beg++; }
    }
    // At most one of the two loops below does any work: copy the leftovers.
    while(src1_beg != src1_end)
    { *dst++ = *src1_beg++; }
    while(src2_beg != src2_end)
    { *dst++ = *src2_beg++; }
}
/* 归并排序
* @result 保存排序结果
* [@beg, @end) 待排序的范围,须支持随机迭代
*/
template<typename _RandomIter,
typename _ResultRandomIter = _RandomIter,
typename _Compare = Comparator<decltype(*std::declval<_RandomIter>())>>
void msort(_ResultRandomIter result,
_RandomIter beg,
_RandomIter end,
_Compare c = _Compare())
{
size_type center = distance(beg, end) / 2;
if(center > 0)
{
msort(result, beg, beg + center, c);
msort(result + center, beg + center, end, c);
merge(result, beg, beg + center, beg + center, end, c);
for(_RandomIter it = beg; it != end; ++it)
{ *it = *result++; }
}
}
/* 归并排序
* [@beg, @end) 待排序的范围,结果将覆盖这个范围,须支持随机迭代
*/
template<typename _RandomIter,
typename _Compare = Comparator<decltype(*std::declval<_RandomIter>())>>
void msort(_RandomIter beg,
_RandomIter end,
_Compare c = _Compare())
{
using _IterValueType = decltype(*std::declval<_RandomIter>());
using _RemoveConstType = typename std::remove_const<_IterValueType>::type;
using _ValueType = typename std::remove_reference<_RemoveConstType>::type;
size_type dist = distance(beg, end);
_ValueType *result = new _ValueType[dist];
msort(result, beg, end, c);
for(size_type i = 0; i < dist; ++i)
{ *beg++ = result[i]; }
delete[] result;
}
如有问题,欢迎指出!