pytorch固定BN层参数[通俗易懂]

pytorch固定BN层参数[通俗易懂]背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。原因:未固定主分支BN层中的running_mean和running_var。解决方法:将需要固定的BN层状态设置为eval。问题示例:环境:torch:1.7.0#-*-coding:utf-8-*-importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.M

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]  # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假设每个epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 计算损失和参数更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])

参考:

Cannot freeze batch normalization parameters

BatchNorm2d

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/183783.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • DropDownList的常用属性和事件「建议收藏」

    DropDownList的常用属性和事件「建议收藏」SelectedItem属性设置或获取下拉菜单的选中项,该属性的类型为System.Web.UI.WebControls.ListItem.所有列表控件(ListControl)中的项都是该类型,它

  • SRVCTL详解

    SRVCTL详解1.SRVCTL概述SRVCTL是ORACLERAC集群配置管理的工具,可以管理Database、Instance、ASM、Service、Listener和NodeApplication,N

  • 挖洞经验丨敏感信息泄露+IDOR+密码确认绕过=账户劫持

    挖洞经验丨敏感信息泄露+IDOR+密码确认绕过=账户劫持本文中涉及到的相关漏洞已报送厂商并得到修复,本文仅限技术研究与讨论,严禁用于非法用途,否则产生的一切后果自行承担。今天分享的这篇Writeup是作者在HackerOne上某个邀请测试项目的发现,目标网站存在不安全的访问控制措施,可以利用其导致的敏感信息泄露(auth_token)+密码重置限制绕过,以越权(IDOR)方式,实现网站任意账户劫持(Takeover)。整个测试过程是一次最基本…

  • 开始激活成功教程so文件_so文件格式怎么打开

    开始激活成功教程so文件_so文件格式怎么打开第一、利用IDA静态分析native函数1.isEquals函数分析函数指令代码:简单分析指令代码:1>、PUSH{r3-r7,lr}是保存r3,r4,r5,r6,r7,lr的值到内存的栈中;与之对应的是POP{r3-r7,pc}pc:程序寄存器,保留下一条CPU即将执行的指令lr:连接返回寄存器,保留函数返回后,下一条应执行的指令2>、调用strlen,malloc,st

  • 进程分析工具 process_grep查看进程

    进程分析工具 process_grep查看进程当进程卡住不动或者死锁时,pstack可以把当前进程的代码栈打出来,方便我们排查。用法非常简单,后面直接加进程号即可。如果是多线程的,则会打印每个线程的堆栈信息。manpstack可查看帮助[root@localhost~]#pstack7383Thread8(Thread0x7fcc0429c700(LWP7384)):#00x00007fcc0d322a82inpthread_cond_timedwait@@GLIBC_2.3.2()from/lib64/.

  • 图像数字化的两种方式是_图像是如何数字化的

    图像数字化的两种方式是_图像是如何数字化的将图片存储为数据有两种方案。其一为位图,也被称为光栅图。即是以自然的光学的眼光将图片看成在平面上密集排布的点的集合。每个点发出的光有独立的频率和强度,反映在视觉上,就是颜色和亮度。这些信息有不同的

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号