有如下两个文本(为了排版,我把标签也放在一起了,数字表示标签)其中data.txt
是样本,label.txt
是标签。
data.txt label.txt
涤纶梭织染色布0
全涤布1
皮革服装2
涤纶梭织染色布0
全涤布1
皮革服装2
短毛绒3
皮革服装2
短毛绒3
仿棉绒4
小缸费5
皮革服装2
短毛绒3
短毛绒3
仿棉绒4
小缸费5
皮革服装2
短毛绒3
皮革服装2
短毛绒3
仿棉绒4
小缸费5
皮革服装2
短毛绒3
短毛绒3
在实际情况中,经常会碰到某些类别样样本仅仅只是出现了几次,但是仍旧占用了一个类别,这就导致训练过程结果很容易遭到这些类别的影响。下面就是去除这些样本的思路:
1.读取所有的标签保存到y_train
中;
2.利用np.unique(y_train)
查看样本中一共有多少类别;
class labels: [0 1 2 3 4 5]# 上面有6个类别
3.利用y_count = np.bincount(y_train)
查看每个类别对应的样本数
samples in each class: [2 2 7 8 3 3]# 表示第0个类别有2个样本,第1个类别有2个样本……
4.利用less_than_label=np.where(y_count<5)[0]
找到所有样本个数少于5的类别号
less than label: [0 1 4 5]# 对应0,1,4,5这4个类别的样本数少于5
5.按行读取data.txt
中的样本和label.txt
中的标签,对比第line_number
行的标签号是否在less_than_label
中;如果在则跳过这一行,如果不在分别写入x_train.txt
和y_train.txt
import pandas as pd
import numpy as np
label = pd.read_csv('./label.txt', names=['c1'])
y_train = np.array(label['c1'])
y_count = np.bincount(y_train)
less_than_label=np.where(y_count<5)[0]
print('original labes:',y_train)
print('class labels:',np.unique(y_train))
print('samples in each class:',y_count)
print('less than label:',less_than_label)
f = open('./x_train.txt', 'w', encoding='utf-8')
p = open('./y_train.txt', 'w', encoding='utf-8')
line_number=0
for line in open('./data.txt',encoding='utf-8'):
if y_train[line_number] in less_than_label:
line_number+=1
continue
line=line.strip('\n')
f.write(line+ '\n')
p.write(str(y_train[line_number]) + '\n')
line_number+=1
f.close()
p.close()
结果:
皮革服装2
皮革服装2
短毛绒3
皮革服装2
短毛绒3
皮革服装2
短毛绒3
短毛绒3
皮革服装2
短毛绒3
皮革服装2
短毛绒3
皮革服装2
短毛绒3
短毛绒3