pytorch 自定义参数不更新

nn.Module中定义参数:不需要加cuda,可以求导,反向传播

class BiFPN(nn.Module):
    def __init__(self, fpn_sizes):

    self.w1 =  nn.Parameter(torch.rand(1))

    print("no---------------------------------------------------",self.w1.data, self.w1.grad)

下面这个例子说明中间变量可能没有梯度,但是最终变量有梯度:

cy1 cd都有梯度

import torch

xP=torch.Tensor([[ 3233.8557,  3239.0657,  3243.4355,  3234.4507,  3241.7087,
          3243.7292,  3234.6826,  3237.6609,  3249.7937,  3244.8623,
          3239.5349,  3241.4626,  3251.3457,  3247.4263,  3236.4924,
          3251.5735,  3246.4731,  3242.4692,  3239.4958,  3247.7283,
          3251.7134,  3249.0237,  3247.5637],
        [ 1619.9011,  1619.7140,  1620.4883,  1620.0642,  1620.2191,
          1619.9796,  1617.6597,  1621.1522,  1621.0869,  1620.9725,
          1620.7130,  1620.6071,  1620.7437,  1621.4825,  1620.5107,
          1621.1519,  1620.8462,  1620.5944,  1619.8038,  1621.3364,
          1620.7399,  1621.1178,  1618.7080],
        [ 1619.9330,  1619.8542,  1620.5176,  1620.1167,  1620.1577,
          1620.0579,  1617.7155,  1621.1718,  1621.1338,  1620.9572,
          1620.6288,  1620.6621,  1620.7074,  1621.5305,  1620.5656,
          1621.2281,  1620.8346,  1620.6021,  1619.8228,  1621.3936,
          1620.7616,  1621.1954,  1618.7983],
        [ 1922.6078,  1922.5680,  1923.1331,  1922.6604,  1922.9589,
          1922.8818,  1920.4602,  1923.8107,  1924.0142,  1923.6907,
          1923.4465,  1923.2820,  1923.5728,  1924.4071,  1922.8853,
          1924.1107,  1923.5465,  1923.5121,  1922.4673,  1924.1871,
          1923.6248,  1923.9086,  1921.9496],
        [ 1922.5948,  1922.5311,  1923.2850,  1922.6613,  1922.9734,
          1922.9271,  1920.5950,  1923.8757,  1924.0422,  1923.7318,
          1923.4889,  1923.3296,  1923.5752,  1924.4948,  1922.9866,
          1924.1642,  1923.6427,  1923.6067,  1922.5214,  1924.2761,
          1923.6636,  1923.9481,  1921.9005]])

yP=torch.Tensor([[ 2577.7729,  2590.9868,  2600.9712,  2579.0195,  2596.3684,
          2602.2771,  2584.0305,  2584.7749,  2615.4897,  2603.3164,
          2589.8406,  2595.3486,  2621.9116,  2608.2820,  2582.9534,
          2619.2073,  2607.1233,  2597.7888,  2591.5735,  2608.9060,
          2620.8992,  2613.3511,  2614.2195],
        [  673.7830,   693.8904,   709.2661,   675.4254,   702.4049,
           711.2085,   683.1571,   684.6160,   731.3878,   712.7546,
           692.3011,   701.0069,   740.6815,   720.4229,   681.8199,
           736.9869,   718.5508,   704.3666,   695.0511,   721.5912,
           739.6672,   728.0584,   729.3143],
        [  673.8367,   693.9529,   709.3196,   675.5266,   702.3820,
           711.2159,   683.2151,   684.6421,   731.5291,   712.6366,
           692.1913,   701.0057,   740.6229,   720.4082,   681.8656,
           737.0168,   718.4943,   704.2719,   695.0775,   721.5616,
           739.7233,   728.1235,   729.3387],
        [  872.9419,   891.7061,   905.8004,   874.6565,   899.2053,
           907.5082,   881.5528,   883.0028,   926.3083,   908.9742,
           890.0403,   897.8606,   934.6913,   916.0902,   880.4689,
           931.3562,   914.4233,   901.2154,   892.5759,   916.9590,
           933.9291,   923.0745,   924.4461],
        [  872.9661,   891.7683,   905.8128,   874.6301,   899.2887,
           907.5155,   881.6916,   883.0234,   926.3242,   908.9561,
           890.0731,   897.9221,   934.7324,   916.0806,   880.4300,
           931.3933,   914.5662,   901.2715,   892.5501,   916.9894,
           933.9813,   923.0823,   924.3654]])


shape=[4000, 6000]
cx,cy1=torch.rand(1,requires_grad=True),torch.rand(1,requires_grad=True)

cd=torch.rand(1,requires_grad=True)
ox,oy=cx,cy1
print('cx:{},cy:{}'.format(id(cx),id(cy1)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
cx,cy=cx*shape[1],cy1*shape[0]
print('cx:{},cy:{}'.format(id(cx),id(cy)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
distance=torch.sqrt(torch.pow((xP-cx),2)+torch.pow((yP-cy),2))
mean=torch.mean(distance,1)
starsFC=cd*torch.pow((distance-mean[...,None]),2)
loss=torch.sum(torch.mean(starsFC,1).squeeze(),0)
loss.backward()
print(loss)
print(cx)
print(cy1)
print("cx",cx.grad)
print("cy",cy1.grad)
print("cd",cd.grad)
print(ox.grad)
print(oy.grad)
print('cx:{},cy:{}'.format(id(cx),id(cy)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
发布了2608 篇原创文章 · 获赞 920 · 访问量 506万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/103672674