【LittleXi】 N-gram模型(C++实现)

N-gram模型(C++实现)

定义:通俗地讲,就是利用前文的单词,来推算下一个出现概率最大的单词

马尔科夫性 (独立性假设)

但是前文或许有很多单词,这样编程很复杂,不妨将比较远的单词抛弃掉,仅取最近的 2 或 3 个单词作为“提示词”,即一个单词的概率只取决于前面固定数量的单词。

本文取前两个单词作为上下文来预测下一个单词,严格来说建立的是三元(trigram)模型,代码实现也非常简单~

代码实现

英文训练版本

#include<iostream>
#include<vector> 
#include<map>
#include<string>
#include<algorithm>
#include<fstream>
#pragma warning(disable:4996)
using namespace std;

// Trigram counts: (w1, w2) -> { w3 -> number of times w3 directly followed "w1 w2" }.
std::map<std::pair<std::string, std::string>, std::map<std::string, int>> mp;
// Prediction table: (w1, w2) -> the word seen most often directly after "w1 w2".
std::map<std::pair<std::string, std::string>, std::string> store_keyval;

/**
 * @brief Counts word trigrams in "train.txt" and fills the prediction table.
 *
 * Whitespace-separated words are read until the sentinel word "my_over" or the
 * end of the file, whichever comes first. Every run of three consecutive words
 * (w1, w2, w3) increments mp[{w1, w2}][w3]. Afterwards store_keyval[{w1, w2}]
 * is set to the follower with the highest count for each context in mp.
 *
 * Counts are added on top of whatever mp already holds, so calling train()
 * more than once accumulates. If the file cannot be opened, a message is
 * written to std::cerr and both tables are left untouched.
 */
void train()
{
	std::ifstream inFile("train.txt");
	if (!inFile)
	{
		std::cerr << "train: cannot open train.txt" << std::endl;
		return;
	}

	std::string older;   // word two positions back
	std::string newer;   // word one position back
	std::string word;    // word just read
	int wordsSeen = 0;

	// Testing the stream state ends the loop at end of file, so a corpus that
	// lacks the "my_over" sentinel no longer spins forever on a stale word.
	while (inFile >> word && word != "my_over")
	{
		// Count first, shift afterwards: this way the very first trigram of
		// the corpus is recorded as well.
		if (wordsSeen >= 2)
			++mp[{older, newer}][word];
		older = newer;
		newer = word;
		++wordsSeen;
	}
	inFile.close();

	for (const auto& context : mp)
	{
		const std::map<std::string, int>& followers = context.second;
		// Every entry in mp was created by an increment, so followers is never
		// empty. std::max_element returns the first maximum, and the inner map
		// iterates in key order, so ties go to the lexicographically smallest word.
		const auto best = std::max_element(
			followers.begin(), followers.end(),
			[](const std::pair<const std::string, int>& lhs,
			   const std::pair<const std::string, int>& rhs) {
				return lhs.second < rhs.second;
			});
		store_keyval[context.first] = best->first;
	}
}

/**
 * @brief Runs one interactive query: reads a length and two seed words from
 *        standard input, then prints that many predicted words.
 *
 * Each predicted word is looked up in store_keyval under the two most recent
 * words; when that context is absent the word "and" is emitted instead. The
 * window then slides forward by one word.
 */
void test()
{
	int len = 0;
	std::cout << "请输入续写的长度:" << std::endl;
	std::cin >> len;

	std::cout << "请输入想要续写的内容" << std::endl;
	std::string older, newer;
	std::cin >> older >> newer;

	for (int remaining = len; remaining > 0; --remaining)
	{
		const auto hit = store_keyval.find({ older, newer });
		const std::string next = (hit != store_keyval.end()) ? hit->second : std::string("and");
		std::cout << next << " ";
		older = newer;
		newer = next;
	}
	std::cout << std::endl;
}

/**
 * @brief Trains the model once, then runs as many interactive queries as the
 *        user asks for.
 * @return 0 after the requested number of queries has been served.
 */
int main()
{
	train();
	int test_time = 0;
	std::cout << "请输入需要询问的次数" << std::endl;
	std::cin >> test_time;
	// Count upwards rather than `while (test_time--)`: with a negative entry the
	// decrementing form never reaches zero and would keep querying.
	for (int round = 0; round < test_time; ++round)
	{
		test();
	}
	return 0;
}

中文训练版本

#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<fstream>
#include<iostream>
#include<map>
#include<string>
#include<vector>
#pragma warning(disable:4996)
using namespace std;


// Trigram counts: (c1, c2) -> { c3 -> number of times c3 directly followed "c1c2" }.
std::map<std::pair<std::string, std::string>, std::map<std::string, int>> mp;
// Prediction table: (c1, c2) -> the character seen most often directly after "c1c2".
std::map<std::pair<std::string, std::string>, std::string> store_keyval;

/**
 * @brief Counts character trigrams in "zh-train.txt" and fills the prediction table.
 *
 * Whitespace-separated tokens are read until the sentinel token "my_over" or
 * the end of the file, whichever comes first. Each token is cut into
 * consecutive two-byte units, and every run of three units (c1, c2, c3) inside
 * one token increments mp[{c1, c2}][c3]. Tokens shorter than three units are
 * skipped, and a trailing odd byte is ignored. Afterwards store_keyval[{c1, c2}]
 * is set to the follower with the highest count for each context in mp.
 *
 * The running token count is printed to std::cout every 100000 tokens. Counts
 * are added on top of whatever mp already holds. If the file cannot be opened,
 * a message is written to std::cerr and both tables are left untouched.
 */
void train()
{
	// NOTE(review): the fixed width of two bytes per character presumably targets a
	// double-byte encoding such as GBK; UTF-8 text or embedded ASCII would be
	// split mid-character. Confirm the encoding of the corpus.
	const std::string::size_type kCharBytes = 2;

	std::ifstream inFile("zh-train.txt");
	if (!inFile)
	{
		std::cerr << "train: cannot open zh-train.txt" << std::endl;
		return;
	}

	std::string token;
	int cnt = 0;
	// Testing the stream state ends the loop at end of file, so a corpus that
	// lacks the "my_over" sentinel no longer spins forever on a stale token.
	while (inFile >> token && token != "my_over")
	{
		++cnt;
		if (cnt % 100000 == 0)
			std::cout << cnt << std::endl;   // progress indicator

		if (token.size() < 3 * kCharBytes)
			continue;

		// pos is the byte offset of the third unit of the current trigram. The
		// token itself is never modified, so the loop walks the whole token, and
		// starting at the third unit records the token's first trigram too.
		for (std::string::size_type pos = 2 * kCharBytes;
		     pos + kCharBytes <= token.size();
		     pos += kCharBytes)
		{
			const std::string first = token.substr(pos - 2 * kCharBytes, kCharBytes);
			const std::string second = token.substr(pos - kCharBytes, kCharBytes);
			const std::string third = token.substr(pos, kCharBytes);
			++mp[{first, second}][third];
		}
	}
	inFile.close();

	for (const auto& context : mp)
	{
		const std::map<std::string, int>& followers = context.second;
		// Every entry in mp was created by an increment, so followers is never
		// empty. std::max_element returns the first maximum, and the inner map
		// iterates in key order, so ties go to the smallest key.
		const auto best = std::max_element(
			followers.begin(), followers.end(),
			[](const std::pair<const std::string, int>& lhs,
			   const std::pair<const std::string, int>& rhs) {
				return lhs.second < rhs.second;
			});
		store_keyval[context.first] = best->first;
	}
}

// Fallback characters: one of these is picked at random when the current
// two-character context has no entry in store_keyval.
std::vector<std::string> dic = { "的","一","了","是","我","不","在","人","们","有" };

/**
 * @brief Runs one interactive query: reads a text from standard input and
 *        prints 300 predicted characters that continue it.
 *
 * The last four bytes of the input are taken as the two seed characters (two
 * bytes each). Each predicted character is looked up in store_keyval under the
 * two most recent characters; when that context is absent, a random entry of
 * dic is used instead. Inputs shorter than four bytes are rejected with a
 * message instead of being passed to substr.
 */
void test()
{
	// NOTE(review): two bytes per character presumably matches a double-byte
	// encoding such as GBK — confirm it agrees with the training corpus and
	// with the encoding this source file (and therefore dic) is saved in.
	const std::string::size_type kCharBytes = 2;
	const int len = 300;   // number of characters to generate

	// Seed only once per run: reseeding from the clock on every call would make
	// two queries issued within the same second draw identical fallbacks.
	static bool seeded = false;
	if (!seeded)
	{
		std::srand(static_cast<unsigned>(std::time(nullptr)));
		seeded = true;
	}

	std::cout << "请输入想要续写的内容" << std::endl;
	std::string s;
	std::cin >> s;

	// s.size() - 4 wraps around for shorter inputs and substr would throw
	// std::out_of_range, so reject them up front.
	if (s.size() < 2 * kCharBytes)
	{
		std::cout << "输入内容太短,至少需要两个汉字" << std::endl;
		return;
	}

	std::string s1 = s.substr(s.size() - 2 * kCharBytes, kCharBytes);
	std::string s2 = s.substr(s.size() - kCharBytes, kCharBytes);
	for (int i = 0; i < len; ++i)
	{
		std::string s3;
		const auto hit = store_keyval.find({ s1, s2 });
		if (hit == store_keyval.end())
			s3 = dic[std::rand() % dic.size()];   // unseen context: random fallback
		else
			s3 = hit->second;
		std::cout << s3 << " ";
		s1 = s2;
		s2 = s3;
	}
	std::cout << std::endl;
}

/**
 * @brief Trains the model once, then runs as many interactive queries as the
 *        user asks for.
 * @return 0 after the requested number of queries has been served.
 */
int main()
{
	train();
	int test_time = 0;
	std::cout << "请输入需要询问的次数" << std::endl;
	std::cin >> test_time;
	// Count upwards rather than `while (test_time--)`: with a negative entry the
	// decrementing form never reaches zero and would keep querying.
	for (int round = 0; round < test_time; ++round)
	{
		test();
	}
	return 0;
}

训练效果

请添加图片描述

猜你喜欢

转载自blog.csdn.net/qq_68591679/article/details/131655369