因为预测的值由原来的四个值,变成现在的图像中心点的两个值,个人感觉可能会带来精度上的提升,附上中心点损失值的代码:
预测图像中点的MSE损失函数:
```python
import torch.nn as nn
class MSE_Loss(nn.Module):
def __init__(self):
super(MSE_Loss, self).__init__()
def forward(self, pred, target):
loss = nn.MSELoss()
return loss(pred, target)
```
在使用时,需要将预测值和真实值传入forward函数中即可。其中,预测值和真实值都是形状为(B,2)的tensor,B表示batch size,2表示坐标轴数目。
```python
loss_fn = MSE_Loss()
pred = model(image) # image为输入的图像
loss = loss_fn(pred, target) # target为真实的中点坐标
需要注意的是,模型传出的数据为两个值,分别代表图像中心的x和y
target也需要提前修改好,原来是中心点,加上长和宽总共四个值,改成中心点两个值就好了
目前正在测试,明天看看结果