Otsu方法
讲解
C++实现
#include <iostream>
#include <vector>
#include <opencv2\opencv.hpp>
#include <opencv2\highgui.hpp>
using namespace std;
using namespace cv;
// get threshold by otsu
int GetOtsuThreshold(Mat& input_img)
{
int thres_val = 1;
vector<int> ihist(256, 0);
int n = 0, n1, n2, Color = 0;
double m1, m2, sum, csum, fmax, sb;
for (int h = 0; h < input_img.rows; ++h)
{
for (int w = 0; w < input_img.cols; ++w)
{
++ihist[input_img.ptr<uchar>(h)[w]];
}
}
sum = csum = 0.0;
for (int k = 0; k <= 255; k++)
{
sum += (double)k* (double)ihist[k];
n += ihist[k];
}
// find otsu threshold
fmax = -1.0;
n1 = 0;
for (int k = 0; k<255; k++)
{
n1 += ihist[k];
if (n1 == 0) // prevent n1 from 0 (compute m1)
continue;
n2 = n - n1;
if (n2 == 0)
break;
csum += (double)k*ihist[k];
m1 = csum / n1;
m2 = (sum - csum) / n2;
// compute current g
sb = (double)n1* (double)n2* (m1 - m2) * (m1 - m2);
// update threshold
if (sb > fmax)
{
fmax = sb;
thres_val = k;
}
}
return thres_val;
}
// binarize image with threshold
void BinarizeImageByThreshold( Mat& input, Mat& output, int threshold )
{
assert( !input.empty() );
output = Mat::zeros( input.size(), CV_8UC1 );
for (int h = 0; h < output.rows; ++h)
{
for (int w = 0; w < output.cols; ++w)
{
if (input.ptr<uchar>(h)[w] > threshold)
{
output.ptr<uchar>(h)[w] = 255;
}
}
}
}
// Loads the file at `fp` into `img` as a grayscale image via cv::imread.
// Returns true when the loaded image is non-empty, false otherwise.
bool ReadImage(string fp, Mat& img)
{
    img = imread( fp, IMREAD_GRAYSCALE );
    if ( img.empty() )
    {
        return false;
    }
    return true;
}
// Demo entry point: loads a grayscale image, binarizes it with the Otsu
// threshold, and shows the input and the result in two windows until a key
// is pressed.
//
// The image path defaults to "img/11.png"; an optional first command-line
// argument overrides it.
// Returns 0 on success, -1 when the image cannot be read.
int main(int argc, char** argv)
{
    const std::string img_fp = (argc > 1) ? argv[1] : "img/11.png";

    Mat input_img;
    if ( !ReadImage( img_fp, input_img ) )
    {
        std::cout << "can not read image, exit!" << std::endl;
        return -1;
    }

    const int thres = GetOtsuThreshold( input_img );
    Mat bw_img;
    BinarizeImageByThreshold( input_img, bw_img, thres );

    imshow("input", input_img);
    imshow("output", bw_img);
    waitKey();
    destroyAllWindows();
    return 0;
}
Python实现
import cv2
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
def ReadImage( fp ):
    """Load the image at path `fp` in grayscale mode.

    Returns whatever cv2.imread returns for `fp` with cv2.IMREAD_GRAYSCALE.
    """
    return cv2.imread( fp, cv2.IMREAD_GRAYSCALE )
def Otsu( img ):
    """Compute a global binarization threshold with Otsu's method.

    Every gray level k in [0, 254] is tried as a split into the classes
    {pixel <= k} and {pixel > k}; the k maximizing the between-class
    variance criterion n1 * n2 * (m1 - m2)**2 is returned. Ties go to the
    smallest k.

    Parameters:
        img: NumPy array of non-negative integer gray levels; expected to be
            an 8-bit (uint8) image so that all values fall in [0, 255].

    Returns:
        The selected threshold as an int, or 0 when no split leaves both
        classes non-empty (e.g. the image holds a single gray level).
    """
    # Histogram of gray levels; bincount replaces a per-pixel Python loop.
    hist = np.bincount( img.ravel(), minlength=256 )
    counts = [ int(c) for c in hist ]

    n = sum( counts )                                       # total pixel count
    h_sum = sum( k * c for k, c in enumerate( counts ) )    # total intensity sum

    thres = 0
    fmax = 0.0
    n1 = 0      # pixel count of the lower class (values <= k)
    csum = 0    # intensity sum of the lower class
    for k in range( 255 ):
        n1 += counts[k]
        if n1 == 0:
            # Lower class still empty: m1 would divide by zero.
            continue
        n2 = n - n1
        if n2 == 0:
            # Upper class is empty from here on: nothing left to split.
            break
        csum += k * counts[k]
        m1 = csum / n1
        # Mean of the upper class: the remaining intensity sum divided by n2.
        m2 = (h_sum - csum) / n2
        curr_f = 1.0 * n1 * n2 * (m1 - m2) ** 2
        if curr_f > fmax:
            fmax = curr_f
            thres = k
    return thres
def Binarize( img, thres ):
    """Return the result of `img > thres`.

    For a NumPy array this is an element-wise comparison, giving a boolean
    mask that is True where the pixel is strictly greater than `thres`.
    """
    return img > thres
# Demo: threshold a sample image with Otsu's method and display the mask.
img_path = "test01.png"
img = ReadImage( img_path )
if img is None:
    # cv2.imread reports an unreadable file by returning None rather than
    # raising, so fail here with a clear message instead of inside Otsu.
    raise IOError( "can not read image: " + img_path )
thres = Otsu( img )
bw = Binarize( img, thres )
print( thres )
plt.imshow( bw )
plt.show()