这几天看算法导论,看到矩阵一章,就实现了一下。
下面是普通的矩阵乘法,复杂度为:n^3。
template<unsigned M,unsigned N, unsigned Q>
void Square_matrix_multiply(int(&A)[M][N], int(&B)[N][Q], int(&C)[M][Q]) {
for (size_t i = 0;i != M;++i) {
for (size_t j = 0;j != Q;++j) {
C[i][j] = 0;
for (size_t n = 0;n != N;++n) {
C[i][j] += A[i][n] * B[n][j];
}
}
}
}
函数接受三个二维数组,A * B得到的矩阵赋值给C。
下面是分治策略的算法。
template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
size_t n = A.rows();
Matrix C(n, n);
if (n == 1)
return C = A.get()*B.get();
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2); // 使用一个类MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2); // 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2); // 进行分割
C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21); // Matrix::operator+;
C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22); // MatrixRef::operator=;
C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
}
return C;
}
矩阵实现了一个Matrix类(具体实现在最下面),有一个构造函数:接受两个size_t值l、r,生成l*r大小值全为0的矩阵。
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
data->resize(l*r);
}
其中hight为矩阵行高,width为列宽,data为shared_ptr,矩阵用vector实现。
A.rows()返回A的width长度(即方矩阵的边长),Matrix(n,n)创建一个矩阵。
size_t rows() const {
return width;
}
如果n==1,通过Matrix的get函数返回第一个元素,也就是唯一的一个元素。
int Matrix::get() const {
return (*data)[0];
}
为了不复制矩阵元素(如果可以复制矩阵元素的话,会简单很多),另实现了一个MatrixRef,其含有:两个size_t数据成员(实现坐标点)、一个size_t数据成员(实现矩阵长度)、一个weak_ptr(指向vector<int>)。
MatrixRef含有两个构造函数:一个接受Matrix加两个size_t;一个接受MatrixRef加两个size_t。都是为了指明引用范围。
MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data),
hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr),
hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
wptr用data或wptr初始化,避免拷贝。length为rows()的返回值除以2,因为是分割为4个矩阵,行列各除以2。
要注意:接受MatrixRef的坐标要加上之前的坐标。
MatrixRef也有一个rows成员函数,为了递归调用。
size_t rows() const {
return length;
}
Square_max_matrix_multiply_recursive函数返回一个Matrix,Matrix实现了operator+,但是行列必须相等。
Matrix& Matrix::operator+=(const Matrix &rhs) {
if (hight == rhs.hight && width == rhs.width) {
for (size_t i = 0;i != size();++i)
(*data)[i] += (*rhs.data)[i];
}
else
throw std::logic_error("Not Matched");
return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
MatrixRef实现了一个operator=。
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {
for (size_t i = 0;i != length;++i) {
for (size_t j = 0;j != length;++j) {
(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] =
rhs.get(i + 1, j + 1); //注意:length*2 因为C也被分割了
}
}
return *this;
}
其中(i + hight_startptr)*length * 2 + j + width_startptr)为当前下标(vector对应矩阵的下标,非矩阵行列)。此函数将分割的C进行“拼合”。注意:length * 2 ,因为C也被分割了,不乘以2为C_11及C_12的长度,乘以2才是C的行列长宽,才能给C的给定位置赋值。
下面是Strassen矩阵算法。
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 为2的幂的情况下
size_t n = A.rows();
Matrix C(n, n);
if (n == 1) {
return C = A.get()*B.get();
}
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2); // 使用一个类MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2); // 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2); // 进行分割
Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、减
S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
C_11 = P5 + P4 - P2 + P6;
C_12 = P1 + P2;
C_21 = P3 + P4;
C_22 = P5 + P1 - P3 - P7;
}
return C;
}
此算法较之前多了一个MatrixRef::operator-、以及MatrixRef::operator+。
Matrix& Matrix::operator-() {
for (auto &f : *data)
f = -f;
return *this;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml = -mr + ml;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
operator-用Matrix的取负、以及Matrix的加法,同时最重要的还有Matrix(const MatrixRef &)。MatrixRef将此对象引用范围内的子矩阵创建一个局部Matrix对象。
operator+用Matrix的加法与Matrix(const MatrixRef &)。
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩阵长度
auto ivec = *rhs.wptr.lock();
for (size_t i = 0; i != hight; ++i) {
for (size_t j = 0; j != width; ++j) {
data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
}
}
}
其中max_size为wptr所指的vector<int>的size,进行根号得到。max_size就是MatrixRef对象未分解(即未分割的C)的矩阵边长。static_cast把sqrt返回的double转未size_t,因为是方矩阵,所以不会损失精度。
(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr)为MatrixRef对象引用范围内对应vector的下标。此必需乘以max_size。
下面为不是2的幂的情况。
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
size_t n = A.rows();
double size = log(n) / log(2);
size_t l_size = static_cast<size_t>(size);
if (l_size != size) {
size_t t_size = (l_size + 1)*(l_size + 1);
Matrix a(t_size, t_size), b(t_size, t_size);
a = A;
b = B;
Matrix C = Strassen_matrix_fit(a, b);
Matrix c(n, n);
c = C;
return c;
}
else
return Strassen_matrix_fit(A, B);
}
size与l_size比较可知是否为2的幂,如果是,执行else,不是,则执行if。
当不是2的幂时,我的思路是把它加0,拼成2的幂的形式。如下图。
1 2 3 1 2 3 0 5 6 7
2 3 2 ---> 2 3 2 0 ---> 4 5 2
3 2 1 3 2 1 0 3 5 6
0 0 0 0
然后得出结果时再切去周围的零,其值是不变的。假如为n*n的矩阵,复杂度(n + k) ^ lg7。n + k 为最接近n的2的幂,其中0<k<n。
(n + k) ^ lg7 < (2n) ^ lg7 = 7 * n ^ lg7。
复杂度还是n ^ lg7。
加零还是切去零,我是通过赋值来实现的。
Matrix& Matrix::operator=(const Matrix &rhs) {
if (hight == rhs.hight) { // rhs this
for (size_t i = 0; i != size(); ++i) { // 1 2 3 1 2 3
(*data)[i] = (*rhs.data)[i]; // 2 3 2 -> 2 3 2
} // 3 2 1 3 2 1
}
else if (hight > rhs.hight) { // 1 2 3 1 2 3 0
for (size_t i = 0;i != hight; ++i) { // 2 3 2 -> 2 3 2 0
for (size_t j = 0, n = 1;j != width; ++j) { // 3 2 1 3 2 1 0
if (j >= rhs.width || i >= rhs.hight) // 0 0 0 0
(*data)[i * width + j] = 0;
else
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];
}
}
}
else { // 1 2 3 4 1 2 3
for (size_t i = 0;i != hight; ++i) { // 2 3 4 3 -> 2 3 4
for (size_t j = 0;j != width; ++j) { // 3 4 3 2 3 4 3
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; // 4 3 2 1
}
}
}
return *this;
}
有三种复制方式:当左矩阵边长与右矩阵边长相等,第一个,正常赋值。左>右时,左上角对其,剩余的赋0。左<右时,左上角对其,多余的切掉。
补充:
在我的电脑上,普通(n^3)的算法与Strassen算法在1500*1500左右的时候时间是差不多的,但是耗时达到30秒,之后Strassen算法会出现明显的优势。在小于100*100的矩阵乘法时普通算法耗时小于0.01秒,而Strassen可达到3秒,普通算法有绝对的优势。
在我的电脑上,把二维数组扩展到300*300以上时会有栈溢出,这时可以上网搜索找到相应的解决办法。
END
设置MatirxRef类不知道好不好,毕竟矩阵拷贝也不影响复杂度。
肯定有很多值得改进的地方,也有不对的地方,可以评论提醒一下。
附:
Matrix头文件。
#ifndef MATRIX_H
#define MATRIX_H
#include<iostream>
#include<memory>
#include<vector>
class MatrixRef;
class Matrix {
friend Matrix operator+(const Matrix &, const Matrix &);
friend Matrix operator-(const Matrix &, const Matrix &);
friend std::ostream& operator<<(std::ostream&, const Matrix &);
friend Matrix operator+(const MatrixRef &, const MatrixRef &);
friend Matrix operator-(const MatrixRef &, const MatrixRef &);
friend class MatrixRef;
public:
Matrix();
template<unsigned M, unsigned N>
Matrix(int(&A)[M][N]) : hight(M), width(N), data(make_shared<vector<int>>()) {
data->reserve(M*N);
for (size_t i = 0;i != M;++i)
for (size_t j = 0;j != N;++j)
data->push_back(A[i][j]);
}
Matrix(size_t l, size_t r); // 创建一个行l、列r的零矩阵
Matrix(const Matrix &rhs); // 深层次拷贝构造
explicit Matrix(const MatrixRef &); // 将MatrixRef转换为Matrix
int& get(size_t l, size_t r); // 取得行L、列R的值
const int& get(size_t l, size_t r) const;
int get() const; // 得到第一个值
Matrix& operator=(const Matrix &rhs); // 深层次拷贝赋值
Matrix& operator+=(const Matrix &rhs);
Matrix& operator=(int i); // 把一个为i的值赋给行为1、列为1的矩阵
Matrix& operator-(); // 对矩阵取负
size_t rows() const {
return width;
}
size_t size() const {
return hight * width;
}
private:
void check_situation(size_t l, size_t r) const {
if (l > hight || r > width)
throw std::range_error("Invalid range");
}
size_t hight = 1;
size_t width = 1;
std::shared_ptr<std::vector<int>> data;
};
class MatrixRef {
friend Matrix operator-(const MatrixRef &, const MatrixRef &);
friend Matrix operator+(const MatrixRef &, const MatrixRef &);
friend class Matrix;
public:
MatrixRef(const Matrix &m, size_t line, size_t row);
MatrixRef(const MatrixRef &mref, size_t line, size_t row);
MatrixRef& operator=(const Matrix &rhs); // 对C_11、C_12、C_13、C_14进行赋值拼接的函数
int& get() const; // 得到第一个值
size_t rows() const {
return length;
}
private:
std::weak_ptr<std::vector<int>> wptr;
size_t hight_startptr;
size_t width_startptr;
size_t length;
};
Matrix operator+(const Matrix &, const Matrix &);
Matrix operator-(const Matrix &, const Matrix &);
Matrix operator+(const MatrixRef &, const MatrixRef &);
Matrix operator-(const MatrixRef &, const MatrixRef &);
std::ostream& operator<<(std::ostream&, const Matrix &);
template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
size_t n = A.rows();
Matrix C(n, n);
if (n == 1)
return C = A.get()*B.get();
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一个类MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 进行分割
C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);// Matrix::operator+;
C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);// MatrixRef::operator=;
C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
}
return C;
}
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 为2的幂的情况下
size_t n = A.rows();
Matrix C(n, n);
if (n == 1) {
return C = A.get()*B.get();
}
else {
MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一个类MatirxRef
MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 进行分割
Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、减
S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
C_11 = P5 + P4 - P2 + P6;
C_12 = P1 + P2;
C_21 = P3 + P4;
C_22 = P5 + P1 - P3 - P7;
}
return C;
}
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
size_t n = A.rows();
double size = log(n) / log(2);
size_t l_size = static_cast<size_t>(size);
if (l_size != size) {
size_t t_size = (l_size + 1)*(l_size + 1);
Matrix a(t_size, t_size), b(t_size, t_size);
a = A;
b = B;
Matrix C = Strassen_matrix_fit(a, b);
Matrix c(n, n);
c = C;
return c;
}
else
return Strassen_matrix_fit(A, B);
}
#endif
头文件实现。
#include"Matrix1.h"
#include<math.h>
using namespace std;
Matrix::Matrix() :data(make_shared<vector<int>>()) {
data->push_back(0);
}
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
data->resize(l*r);
}
Matrix::Matrix(const Matrix &rhs) : hight(rhs.hight), width(rhs.width), data(make_shared<vector<int>>()) {
for (size_t i = 0;i != size(); ++i)
data->push_back((*rhs.data)[i]);
}
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩阵长度
auto ivec = *rhs.wptr.lock();
for (size_t i = 0; i != hight; ++i) {
for (size_t j = 0; j != width; ++j) {
data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
}
}
}
int& Matrix::get(size_t l, size_t r) {
check_situation(l, r);
return (*data)[--l * width + --r];
}
const int& Matrix::get(size_t l, size_t r) const {
check_situation(l, r);
return (*data)[--l * width + --r];
}
int Matrix::get() const {
return (*data)[0];
}
Matrix& Matrix::operator=(const Matrix &rhs) {
if (hight == rhs.hight) { // rhs this
for (size_t i = 0; i != size(); ++i) { // 1 2 3 1 2 3
(*data)[i] = (*rhs.data)[i]; // 2 3 2 -> 2 3 2
} // 3 2 1 3 2 1
}
else if (hight > rhs.hight) { // 1 2 3 1 2 3 0
for (size_t i = 0;i != hight; ++i) { // 2 3 2 -> 2 3 2 0
for (size_t j = 0, n = 1;j != width; ++j) { // 3 2 1 3 2 1 0
if (j >= rhs.width || i >= rhs.hight) // 0 0 0 0
(*data)[i * width + j] = 0;
else
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];
}
}
}
else { // 1 2 3 4 1 2 3
for (size_t i = 0;i != hight; ++i) { // 2 3 4 3 -> 2 3 4
for (size_t j = 0;j != width; ++j) { // 3 4 3 2 3 4 3
(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; // 4 3 2 1
}
}
}
return *this;
}
Matrix& Matrix::operator+=(const Matrix &rhs) {
if (hight == rhs.hight && width == rhs.width) {
for (size_t i = 0;i != size();++i)
(*data)[i] += (*rhs.data)[i];
}
else
throw std::logic_error("Not Matched");
return *this;
}
Matrix& Matrix::operator=(int i) {
if (hight == width && hight == 1)
(*data)[0] = i;
return *this;
}
Matrix& Matrix::operator-() {
for (auto &f : *data)
f = -f;
return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
Matrix m(lhs);
return m += rhs;
}
Matrix operator-(const Matrix &lhs,const Matrix &rhs) {
Matrix m(rhs);
return m = -m + lhs;
}
MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr), hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {
for (size_t i = 0;i != length;++i) {
for (size_t j = 0;j != length;++j) {
(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = rhs.get(i + 1, j + 1); //注意:length*2 因为C也被分割了
}
}
return *this;
}
int& MatrixRef::get() const {
return (*wptr.lock())[static_cast<size_t>(hight_startptr*sqrt(wptr.lock()->size())) + width_startptr];
}
Matrix operator+(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml += mr;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
Matrix ml(lhs), mr(rhs);
return ml = -mr + ml;
}
ostream& operator<<(ostream &os, const Matrix &m) {
int i = 0;
for (auto f : *m.data) {
cout << f;
if (++i == m.width) {
std::cout << '\n';
i = 0;
}
else
cout << ' ';
}
return os;
}
END