(OpenCV 版本为 3.4.0)
OpenCV 官方 API 文档:https://docs.opencv.org/
使用 ml 模块中的 SVM:
Python 版本
1、生成训练数据
训练数据按类别存放:根目录下每个子目录以整数类标签命名,子目录中存放属于该类别的图片。
def generate_data(self, file_dir):
    """Collect flattened training images and their integer class labels.

    Expected layout: ``file_dir/<label>/<image files>``. Every sub-directory
    name is converted with ``int()`` and used as the class label of the
    images inside it (a non-integer directory name raises ``ValueError``).
    Plain files directly inside ``file_dir`` are ignored.

    Each image is read with ``cv2.imread`` (3-channel BGR by default),
    resized to ``self.resize`` with ``INTER_CUBIC`` (presumably a
    ``(width, height)`` tuple, as ``cv2.resize`` expects) and flattened into
    a 1-D array of ``width * height * 3`` values. Files that ``cv2.imread``
    cannot decode are skipped instead of crashing ``cv2.resize``.

    Returns:
        ``(train_data, train_labels)``: two parallel lists. Both are empty
        when ``file_dir`` does not exist.
    """
    train_data = []
    train_labels = []
    if not os.path.exists(file_dir):
        return train_data, train_labels

    # Sorted so the sample order (and therefore the trained model) is reproducible.
    for class_name in sorted(os.listdir(file_dir)):
        class_dir = os.path.join(file_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        label = int(class_name)
        for file_name in sorted(os.listdir(class_dir)):
            img = cv2.imread(os.path.join(class_dir, file_name))
            if img is None:
                # cv2.imread returns None (it does not raise) for unreadable files.
                continue
            img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
            train_data.append(img.reshape(-1))
            train_labels.append(label)
    return train_data, train_labels
2、训练
def svmtrain(train_data, train_labels, model_path="svm_data.dat"):
    """Train a linear C-SVC on flattened samples and save it to ``model_path``.

    Args:
        train_data: sequence of equal-length 1-D feature vectors. It is
            converted to a float32 matrix with one sample per row, which is
            what ``cv2.ml`` requires.
        train_labels: one integer class label per sample.
        model_path: file the trained model is saved to. The default keeps
            the previously hard-coded ``"svm_data.dat"``.

    Returns:
        The trained ``cv2.ml.SVM`` instance.

    Raises:
        ValueError: if there are no samples, or the two sequences differ in length.
        RuntimeError: if ``svm.train`` reports failure.
    """
    # Validate before touching OpenCV so callers get a clear error message.
    if len(train_data) == 0:
        raise ValueError("train_data is empty")
    if len(train_data) != len(train_labels):
        raise ValueError("train_data and train_labels must have the same length")

    samples = np.array(train_data, np.float32)
    # C_SVC expects int32 responses as an (N, 1) column vector.
    responses = np.array(train_labels, np.int32).reshape(-1, 1)

    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    svm.setKernel(cv2.ml.SVM_LINEAR)
    svm.setC(1.0)
    if not svm.train(samples, cv2.ml.ROW_SAMPLE, responses):
        raise RuntimeError("SVM training failed")
    svm.save(model_path)
    return svm
3、测试
def svmtest(model_path, test_file, resize, label=None):
    """Predict the class of one image and save an annotated copy of it.

    The image is preprocessed like the training samples in this script:
    resized to ``resize`` with ``INTER_CUBIC`` and flattened into a single
    float32 row (``cv2.ml`` rejects non-float32 samples). The predicted class
    is drawn onto the original-size image, together with ``label`` when the
    caller knows the ground truth. The result is written to
    ``"out" + basename(test_file)`` in the current working directory.

    Args:
        model_path: path of a model saved by ``svm.save``.
        test_file: image to classify.
        resize: the same size that was used for the training images.
        label: optional ground-truth class, drawn for visual comparison.

    Returns:
        The predicted class as an ``int``.

    Raises:
        FileNotFoundError: if ``test_file`` does not exist.
        ValueError: if ``test_file`` exists but cannot be decoded as an image.
    """
    if not os.path.isfile(test_file):
        raise FileNotFoundError(test_file)
    img = cv2.imread(test_file)
    if img is None:
        raise ValueError("cannot decode image: " + test_file)

    resized = cv2.resize(img, resize, interpolation=cv2.INTER_CUBIC)
    sample = resized.reshape(1, -1).astype("float32")

    svm = cv2.ml.SVM_load(model_path)
    _, res = svm.predict(sample)
    predicted = int(res[0][0])

    # Annotate the full-size image: text at scale 1 does not fit on a tiny resized copy.
    cv2.putText(img, 'result:' + str(predicted), (0, 20), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 1)
    if label is not None:
        cv2.putText(img, 'label:' + str(label), (0, 50), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 1)
    cv2.imwrite("out" + os.path.basename(test_file), img)
    return predicted
C++ 版本
扫描二维码关注公众号,回复:
3565907 查看本文章
1、生成训练数据
void getFiles(string path, vector<string>& files, vector<int> &trainingLabels, int &label, vector<Mat>& trainingImages, Size dsize)
{
DIR *p_dir;
// path = path.append("/");
const char* str = path.c_str();
p_dir = opendir(str);
if( p_dir == NULL)
{
cout<< "can't open :" << path << endl;
}
struct dirent *p_dirent;
while ( p_dirent = readdir(p_dir))
{
string tmpFileName = p_dirent->d_name;
if( tmpFileName == "." || tmpFileName == "..")
{
continue;
}
else
{
cout<<"===========================02"<<endl;
string filepath = path + + "/" + tmpFileName;
cout<<"filename:"<<filepath<<endl;
char const* filename = filepath.data();
struct stat s_buf;
/*获取文件信息,把信息放到s_buf中*/
stat(filename, &s_buf);
if(S_ISDIR(s_buf.st_mode))
{
cout<<"===========================03"<<endl;
label = atoi(tmpFileName.c_str());
getFiles(filepath,files, trainingLabels, label, trainingImages, dsize);
}/*若输入的文件路径是普通文件,则打印并退出程序*/
else if(S_ISREG(s_buf.st_mode))
{
cout<<"===========================04"<<endl;
files.push_back(filepath);
cout<<"===========================05"<<endl;
Mat SrcImage=imread(filepath);
cout<<"===========================06"<<endl;
// 缩小图像
Mat DstImage;
resize(SrcImage, DstImage, dsize, 0, 0, INTER_LINEAR);
DstImage= DstImage.reshape(1, 1);
cout<<"===========================07"<<endl;
trainingImages.push_back(DstImage);
cout<<"===========================08"<<endl;
trainingLabels.push_back(label);
cout<<"===========================09"<<endl;
}
}
}
closedir(p_dir);
}
2、SVM 训练
int main()
{
//获取训练数据
Mat classes;
// Mat trainingData;
vector<Mat> trainingImages;
vector<int> trainingLabels;
int label = 0;
string path = "";
string model_path = "";
vector<string> files;
Size dsize = Size(20, 20);
cout<<"===========================01"<<endl;
getFiles(path, files, trainingLabels, label, trainingImages, dsize);
// getFiles(string path, vector<string>& files, vector<int> &trainingLabels, const int &label, Mat& trainingImages)
// get_1(trainingImages, trainingLabels);
// get_0(trainingImages, trainingLabels);
Mat trainingData(trainingImages.size(), trainingImages[0].cols, CV_32FC1);
for (int i = 0; i < trainingImages.size(); i++)
{
Mat temp(trainingImages[i]);
temp.copyTo(trainingData.row(i));
}
trainingData.convertTo(trainingData, CV_32FC1);
Mat(trainingLabels).copyTo(classes);
// classes.convertTo(classes, CV_32SC1);
//配置SVM训练器参数
Ptr<SVM> model = SVM::create();//以下是设置SVM训练模型的配置
model->setType(SVM::C_SVC);
model->setKernel(SVM::LINEAR);
model->setGamma(1);
model->setC(1);
model->setCoef0(0);
model->setNu(0);
model->setP(0);
model->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 20000, 0.0001));
Ptr<TrainData> tdata = TrainData::create(trainingData, ROW_SAMPLE, classes);
//model->train(trainingData, ROW_SAMPLE, classes);
model->train(tdata);
model->save(model_path);//保存
// svm.save(model_path);
cout<<"训练好了!!!"<<endl;
// getchar();
return 0;
}
3、SVM 测试
/**
 * List the regular files directly inside `inputDirectory` (Linux/POSIX).
 *
 * Appends the bare names (not full paths) of the regular files found in the
 * directory to `fileList`. "." / ".." and sub-directories are skipped, so
 * callers can pass every entry straight to imread(). Entry order is whatever
 * readdir() yields, which is unspecified.
 *
 * @return the size of `fileList` after scanning. If the directory cannot be
 *         opened, a message is printed and `fileList` is left unchanged.
 */
int scanFiles(std::vector<std::string>& fileList, std::string inputDirectory)
{
    inputDirectory.append("/");
    DIR* p_dir = opendir(inputDirectory.c_str());
    if (p_dir == nullptr)
    {
        std::cout << "can't open :" << inputDirectory << std::endl;
        return static_cast<int>(fileList.size());  // never call readdir(NULL)
    }

    while (const struct dirent* p_dirent = readdir(p_dir))
    {
        const std::string name = p_dirent->d_name;
        if (name == "." || name == "..")
        {
            continue;
        }
        struct stat s_buf;
        if (stat((inputDirectory + name).c_str(), &s_buf) == 0 && S_ISREG(s_buf.st_mode))
        {
            fileList.push_back(name);
        }
    }
    closedir(p_dir);
    return static_cast<int>(fileList.size());
}
int main(int argc, char** argv) {
//读取文件夹下所有文件
string file_path;
string out_path;
string svm_model_path;
vector<string> files;
int size = scanFiles(files, file_path);
//加载svm模型
Ptr<SVM> model = SVM::load(svm_model_path);
for (int i = 0;i < size;i++)
{
string filename = file_path + '/' + files[i].c_str();
string out_filename = out_path + '/' + files[i].c_str();
// cout<<"filename:"<<filename<<endl;
Mat img = imread(filename);
Mat DestImage;
DestImage= img.reshape(1, 1);
DstImage.convertTo(DstImage, CV_32FC1);
float response = model->predict(DstImage);
}
return 0;
}