torch自带插值函数

class Model(nn.Module):#2022.11.7修改前,这个Model能跑通#forMultivariate
        
    def __init__(self,configs,channel=96,ratio=1):#channel针对ili数据集应该改成36 channel=input_length
        super(Model, self).__init__()
 
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.union = 1*self.pred_len
        self.Linear = nn.Linear(self.seq_len,self.union)
        self.Linear_1 = nn.Linear(self.union, self.pred_len)

        self.fc1 = nn.Linear(self.union,2*self.union)
        self.fc2 = nn.Linear(self.union,2*self.union)
        self.fc3 = nn.Linear(2*self.union,1)
 

    def forward(self, x):
        x = x.permute(0,2,1) # (B,L,C)=》(B,C,L)#forL
        b, c, l = x.size() # (B,C,L)
        
        
        y = self.Linear(x)#forL
        y_r = self.Linear_1(y) 
        y_shuffle = y[torch.randperm(y.size(0))]#将 y按照批次维度打乱顺序得到y_shuffle
        
        x = F.interpolate(x, size=self.union)#插值,mode选择默认
        # y = F.interpolate(y, size=self.seq_len)#插值,mode选择默认
        # #joint    
        # h1 = F.relu(self.Linear(x) + self.Linear(y))
        # pred_xy = self.Linear(h1)
        # # #marginal
        # h2 = F.relu(self.Linear(x) + self.Linear(y_shuffle))
        # pred_x_y = self.Linear(h2) 
        #joint    
        h1 = F.relu(self.fc1(x) + self.fc2(y))
        h2 = F.relu(self.fc1(x) + self.fc2(y))
        pred_xy = self.fc3(h1)
        # #marginal
        h2 = F.relu(self.fc1(x) + self.fc2(y_shuffle))
        pred_x_y = self.fc3(h2)

        return  y_r.permute(0,2,1),pred_xy,pred_x_y

为了让中间层特征与输入对齐,使用toch.nn.Functional.interpolate.

参考资料

torch.nn.interpolate—torch上采样和下采样操作_两只蜡笔的小新的博客-CSDN博客_torch 下采样

torch.nn.functional中的interpolate插值函数_大梦冲冲冲的博客-CSDN博客_torch 插值

pytorch torch.nn.functional实现插值和上采样_weixin_30905133的博客-CSDN博客

猜你喜欢

转载自blog.csdn.net/weixin_43332715/article/details/128439696