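PT DNN: Binary classification on the Titanic dataset (one-hot/label encoding) with a shallow neural network in PyTorch, including exporting and loading .pth and ONNX model files for inference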
Related articles
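PT DNN: Binary classification on the Titanic dataset (one-hot/label encoding) with a shallow neural network in PyTorch (exporting and loading .pth and ONNX models for inference): implementation code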
# 1. Load the dataset
D:\ProgramData\Anaconda3\python.exe E:/File_Python/Python_daydayup/20230512.py
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None
PassengerId Survived Pclass ... Fare Cabin Embarked
0 1 0 3 ... 7.2500 NaN S
1 2 1 1 ... 71.2833 C85 C
2 3 1 3 ... 7.9250 NaN S
3 4 1 1 ... 53.1000 C123 S
4 5 0 3 ... 8.0500 NaN S
[5 rows x 12 columns]
# Select the model input features
after featuresIN………………………………………………
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Survived 891 non-null int64
1 Pclass 891 non-null int64
2 Age 714 non-null float64
3 Fare 891 non-null float64
4 Sex 891 non-null object
5 Embarked 889 non-null object
dtypes: float64(2), int64(2), object(2)
memory usage: 41.9+ KB
None
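Loading the data and selecting the input features can be sketched as follows. This is a minimal, self-contained illustration: the inline CSV is hypothetical toy data that only mimics the Kaggle Titanic column layout (with the real file one would call `pd.read_csv` on `train.csv` instead), and the feature list is read off the `after featuresIN` output above. The variable names are illustrative, not the author's.

```python
import io

import pandas as pd

# Hypothetical toy rows with the Titanic column layout (not the real data).
# With the real file this would simply be pd.read_csv("train.csv").
csv_text = """PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,Passenger A,male,22,1,0,T1,7.25,,S
2,1,1,Passenger B,female,38,1,0,T2,71.2833,C85,C
3,1,3,Passenger C,female,,0,0,T3,7.925,,S
4,1,1,Passenger D,female,35,1,0,T4,53.1,C123,S
5,0,3,Passenger E,male,35,0,0,T5,8.05,,
"""
df = pd.read_csv(io.StringIO(csv_text))  # empty fields become NaN

# Keep the label plus the five input features used in this case study.
features_in = ["Survived", "Pclass", "Age", "Fare", "Sex", "Embarked"]
df = df[features_in]
df.info()
```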
# 2. Data preprocessing
# 2.1. Handle missing values
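After preprocessing, `df.info()` reports 891 non-null values for `Age` (714 before) and for `Embarked` (889 before), but the fill strategy itself is not shown. A common choice, assumed here, is the median for the numeric `Age` column and the most frequent value (mode) for the categorical `Embarked` column. The toy frame is hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame with the same kinds of gaps as the real data.
df = pd.DataFrame({
    "Age": [22.0, np.nan, 26.0, 35.0, np.nan],
    "Embarked": ["S", "C", np.nan, "S", "S"],
})

# Assumed strategy: median for the numeric column, mode for the categorical one.
df["Age"] = df["Age"].fillna(df["Age"].median())
df["Embarked"] = df["Embarked"].fillna(df["Embarked"].mode()[0])
print(df)
```

The median is less sensitive than the mean to the long tail of older passengers, which is one reason it is a frequent default for `Age`.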
# 2.2. Encode categorical features
# T1. One-hot encoding
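One-hot encoding turns each category value into its own 0/1 column. A minimal sketch with `pd.get_dummies` on toy values follows. Note that one-hot encoding `Sex` (2 values) and `Embarked` (3 values) alongside `Pclass`, `Age` and `Fare` would give 8 input columns, whereas the training tensor printed in this run has 5 columns, which is consistent with the label-encoded variant (T2) being the one used for training.

```python
import pandas as pd

# Hypothetical toy values for the two categorical features.
df = pd.DataFrame({"Sex": ["male", "female", "female"], "Embarked": ["S", "C", "Q"]})

# One 0/1 column per category value; dtype=int avoids boolean columns.
one_hot = pd.get_dummies(df, columns=["Sex", "Embarked"], dtype=int)
print(one_hot.columns.tolist())
# ['Sex_female', 'Sex_male', 'Embarked_C', 'Embarked_Q', 'Embarked_S']
```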
# T2. Label encoding
LBEncode………………………………………………
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Survived 891 non-null int64
1 Pclass 891 non-null int64
2 Age 891 non-null float64
3 Fare 891 non-null float64
4 Sex 891 non-null int32
5 Embarked 891 non-null int32
dtypes: float64(2), int32(2), int64(2)
memory usage: 34.9 KB
None
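Label encoding instead maps each category to a single integer column. The sketch below uses scikit-learn's `LabelEncoder`, which is an assumption since the source does not show which encoder was used. `LabelEncoder` assigns codes in sorted label order, so `female` becomes 0 and `male` 1, while `C`, `Q`, `S` become 0, 1, 2. That mapping is consistent with the `Sex` (0/1) and `Embarked` (0 to 2) columns in the training tensor printed further on.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy values for the two categorical features.
df = pd.DataFrame({"Sex": ["male", "female", "male"], "Embarked": ["S", "C", "Q"]})

for col in ["Sex", "Embarked"]:
    # A fresh encoder per column; codes follow the sorted category labels.
    df[col] = LabelEncoder().fit_transform(df[col])
print(df)
```

Label encoding keeps the input at 5 columns but imposes an artificial order on `Embarked` (C < Q < S), which one-hot encoding avoids.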
# 2.3. Separate features and labels
# 3. Model training and evaluation
# 3.1. Split the dataset
# Convert the data to PyTorch tensors
X_train_tensor tensor([[ 1.0000, 16.0000, 57.9792, 0.0000, 0.0000],
[ 3.0000, 31.0000, 7.7500, 1.0000, 1.0000],
[ 3.0000, 45.5000, 7.2250, 1.0000, 0.0000],
...,
[ 3.0000, 32.0000, 7.9250, 1.0000, 2.0000],
[ 3.0000, 30.0000, 7.2500, 1.0000, 2.0000],
[ 3.0000, 29.0000, 7.7500, 1.0000, 1.0000]])
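Separating the label, splitting, and converting to tensors can be sketched as below. The random frame is a hypothetical stand-in for the encoded 891-row data, and `test_size=0.2` and `random_state=42` are assumptions. `train_test_split` rounds the test share up, so 0.2 of 891 rows gives ceil(178.2) = 179 validation rows, matching the 179-row validation tensor seen in the export section. The feature order (`Pclass`, `Age`, `Fare`, `Sex`, `Embarked`) follows the tensor printed above.

```python
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the encoded 891-row frame (random values).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Survived": rng.integers(0, 2, n),
    "Pclass": rng.integers(1, 4, n),
    "Age": rng.uniform(1, 80, n),
    "Fare": rng.uniform(0, 500, n),
    "Sex": rng.integers(0, 2, n),
    "Embarked": rng.integers(0, 3, n),
})

# 2.3: separate the label from the five input features.
X = df.drop(columns="Survived").values
y = df["Survived"].values

# 3.1: assumed 80/20 split; the test share is rounded up to 179 rows.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# float32 tensors; labels reshaped to (N, 1) to match a single-output network.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).view(-1, 1)
print(X_train_tensor.shape, X_val_tensor.shape)
```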
# 3.2. Define the model: a feedforward neural network
# Initialize the model
# Define the loss function and optimizer
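A shallow feedforward network for this task maps the 5 input features through one hidden layer to a single survival probability. The sketch below is an assumption about the architecture: the class name `ShallowNet`, the hidden size of 16, ReLU, the Adam optimizer and the learning rate of 0.01 are illustrative choices, not taken from the source.

```python
import torch
import torch.nn as nn


class ShallowNet(nn.Module):
    """Hypothetical one-hidden-layer network: 5 inputs -> hidden -> 1 probability."""

    def __init__(self, in_features: int = 5, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # squashes the output to a probability in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = ShallowNet()
criterion = nn.BCELoss()  # binary cross-entropy on probabilities
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

out = model(torch.randn(4, 5))  # a batch of 4 toy samples
print(out.shape)
```

If the final `nn.Sigmoid()` is dropped, `nn.BCEWithLogitsLoss` can be used instead; it combines the sigmoid and the loss in one numerically more stable step.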
# 3.3. Train the model (forward pass + backpropagation)
# Forward pass
# Backward pass and optimization
# Log once every 10 epochs
Epoch [10/100], Loss: 1.3554
Epoch [20/100], Loss: 0.6037
Epoch [30/100], Loss: 0.5943
Epoch [40/100], Loss: 0.5738
Epoch [50/100], Loss: 0.5569
Epoch [60/100], Loss: 0.5480
Epoch [70/100], Loss: 0.5396
Epoch [80/100], Loss: 0.5304
Epoch [90/100], Loss: 0.5220
Epoch [100/100], Loss: 0.5146
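The training loop behind a log like the one above can be sketched as follows. It assumes full-batch training (all training rows in every step, no mini-batching) and uses hypothetical toy data whose label depends on the first feature; the architecture and optimizer settings are the same illustrative assumptions as in a typical shallow network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy data: 200 samples, 5 features, label depends on feature 0.
X_train_tensor = torch.randn(200, 5)
y_train_tensor = (X_train_tensor[:, 0] > 0).float().view(-1, 1)

model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 100
losses = []
for epoch in range(num_epochs):
    model.train()
    # Forward pass on the full training set.
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # Backward pass and optimization step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Log once every 10 epochs, in the same format as the log above.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
```

The tensor printed earlier contains raw `Age` and `Fare` values, which are on much larger scales than the encoded columns. Standardizing them (for example with scikit-learn's `StandardScaler`, fitted on the training split only) is a common step that may speed up convergence; the source does not show scaling.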
# 3.4. Evaluate the model and output predictions
# Model evaluation
AUC Score: 0.8402
F1 Score: 0.6557
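AUC and F1 measure different things: AUC is computed from the predicted probabilities and does not depend on a threshold, while F1 is computed from hard 0/1 predictions, here assumed to use a 0.5 cut-off. The helper names `predict_proba` and `evaluate` below are hypothetical, and the four-sample arrays are hand-made values, not results from the Titanic run.

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def predict_proba(model, X_tensor):
    """Return probabilities as a 1-D NumPy array (assumes a sigmoid output)."""
    model.eval()  # inference mode
    with torch.no_grad():  # no gradient tracking needed
        return model(X_tensor).squeeze(1).numpy()


def evaluate(y_true, y_prob, threshold=0.5):
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "auc": roc_auc_score(y_true, y_prob),  # threshold-free ranking metric
        "f1": f1_score(y_true, y_pred),  # depends on the chosen cut-off
        "accuracy": accuracy_score(y_true, y_pred),
    }


# Hand-made example values (not from the Titanic run).
y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])
m = evaluate(y_true, y_prob)
print(m)  # expected values: auc 0.75, f1 2/3, accuracy 0.75

# Usage of predict_proba with a tiny untrained model on toy input.
toy_model = torch.nn.Sequential(torch.nn.Linear(5, 1), torch.nn.Sigmoid())
probs = predict_proba(toy_model, torch.randn(3, 5))
```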
# 3.5. Model export and inference
# T1. Export and load a .pth model file for inference
AUC Score: 0.8760
F1 Score: 0.7087
model_pth -------------------
Accuracy on validation set: 0.7933
PassengerId Survived loaded_model_pth_y_prob
0 172 0.809788 0.809788
1 524 0.119221 0.119221
2 452 0.302843 0.302843
3 170 0.295056 0.295056
4 620 0.184165 0.184165
.. ... ... ...
174 388 0.106167 0.106167
175 338 0.071640 0.071640
176 827 0.366535 0.366535
177 773 0.119190 0.119190
178 221 0.157593 0.157593
[179 rows x 3 columns]
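Saving and reloading the weights as a .pth file can be sketched with `state_dict`, which stores only the parameters, so the same architecture must be rebuilt before loading. In the table above, the `Survived` column appears to repeat the predicted probability rather than the 0/1 label; the sketch keeps the true label and the probability in separate columns. The `build_model` helper, the temporary file path and the toy data are hypothetical.

```python
import os
import tempfile

import pandas as pd
import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; it must match the one used when saving.
    return nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())


torch.manual_seed(0)
model = build_model()
X_val_tensor = torch.randn(179, 5)  # toy validation features
y_val = (X_val_tensor[:, 0] > 0).int()  # toy 0/1 labels
passenger_ids = list(range(1, 180))  # toy PassengerId values

path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(model.state_dict(), path)  # save the weights only

loaded_model = build_model()
loaded_model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
loaded_model.eval()  # inference mode
with torch.no_grad():
    y_prob = loaded_model(X_val_tensor).squeeze(1).numpy()

result = pd.DataFrame({
    "PassengerId": passenger_ids,
    "Survived": y_val.numpy(),  # true label, kept separate from the prediction
    "loaded_model_pth_y_prob": y_prob,
})
print(result.head())
```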
# T2. Export and load an ONNX model
AUC Score: 0.8402
F1 Score: 0.6557
model_ONNX -------------------
dummy_input: torch.Size([179, 5])
============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
X_val_tensor: (179, 5)
PassengerId Survived loaded_model_ONNX_y_prob
0 172 0.549517 0.549517
1 524 0.181606 0.181606
2 452 0.383283 0.383283
3 170 0.350839 0.350839
4 620 0.299215 0.299215
.. ... ... ...
174 388 0.155465 0.155465
175 338 0.082411 0.082410
176 827 0.301790 0.301790
177 773 0.181496 0.181496
178 221 0.243588 0.243588
[179 rows x 3 columns]