自定义的torch中与np.nanmean同功能的函数(只能做一维的)

def torch_nanmean(x):
    num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
    value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
    return value / num
发布了38 篇原创文章 · 获赞 98 · 访问量 36万+

猜你喜欢

转载自blog.csdn.net/xijuezhu8128/article/details/86590673
今日推荐