[PyTorch]PyTorch中模型的参数初始化的⼏种⽅法(转)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
本⽂⽬录
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
转载请注明出处:
参考⽹址:
说明:暂时就这么多吧,错误之处请见谅。前两个初始化的⽅法见pytorch官⽅⽂档
1. xavier初始化
对于输⼊的tensor或者变量,通过论⽂Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的⽅法初始化数据。初始化服从均匀分布,其中,该初始化⽅法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))
对于输⼊的tensor或者变量,通过论⽂Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的⽅法初始化数据。初始化服从⾼斯分布,其中,该初始化⽅法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)
2. kaiming初始化
对于输⼊的tensor或者变量,通过论⽂“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的⽅法初始化数据。初始化服从均匀分布,其中,该初始化⽅法也称He initialisation。
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:该层后⾯⼀层的激活函数中负的斜率(默认为ReLU,此时a=0)
mode:‘fan_in’ (default) 或者 ‘fan_out’. 使⽤fan_in保持weights的⽅差在前向传播中不变;使⽤fan_out
保持weights的⽅差在反向传播中不变。
例如:
w = torch.Tensor(3, 5)
nn.init.kaiming_uniform(w, mode='fan_in')
对于输⼊的tensor或者变量,通过论⽂“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的⽅法初始化数据。初始化服从⾼斯分布,其中,该初始化⽅法也称He initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:该层后⾯⼀层的激活函数中负的斜率(默认为ReLU,此时a=0)
mode:‘fan_in’ (default) 或者 ‘fan_out’. 使⽤fan_in保持weights的⽅差在前向传播中不变;使⽤fan_out
保持weights的⽅差在反向传播中不变。
例如:
w = torch.Tensor(3, 5)
nn.init.kaiming_normal(w, mode='fan_out')
使⽤的例⼦(具体参见原始⽹址):
import init
self.classifier = nn.Linear(self.stages[3], nlabels)
init.kaiming_normal(self.classifier.weight)
for key in self.state_dict():
if key.split('.')[-1] == 'weight':
if'conv'in key:
init.kaiming_normal(self.state_dict()[key], mode='fan_out')
if'bn'in key:
self.state_dict()[key][...] = 1
elif key.split('.')[-1] == 'bias':
self.state_dict()[key][...] = 0
3. 实际使⽤中看到的初始化
3.1 ResNeXt,densenet中初始化
conv
n = kW* kH*nOutputPlane
weight:normal(0,math.sqrt(2/n))
bias:zero()
batchnorm
weight:fill(1)
bias:zero()weight的几种形式
bias:zero()
3.2 wide-residual-networks中初始化(MSRinit)
conv
n = kW* kH*nInputPlane
weight:normal(0,math.sqrt(2/n))
bias:zero()
linear
bias:zero()
posted @ 2018-12-08 17:11 阅读( ...) 评论( ...)