Discuz! Board

 找回密码
 立即注册
搜索
热搜: 活动 交友 discuz
查看: 60|回复: 0

CNN学习

[复制链接]

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
发表于 3 天前 | 显示全部楼层 |阅读模式
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F

  4. # 设置随机种子确保可重复性
  5. torch.manual_seed(42)

  6. class SimpleCNN(nn.Module):
  7.     def __init__(self):
  8.         super(SimpleCNN, self).__init__()
  9.         # 手动初始化卷积层参数
  10.         self.conv1 = nn.Conv2d(1, 2, kernel_size=2, stride=2, bias=True)
  11.         self.conv1.weight.data = torch.tensor([
  12.             [[[0.1, 0.2], [0.3, 0.4]]],  # 第一个滤波器
  13.             [[[0.5, 0.6], [0.7, 0.8]]]   # 第二个滤波器
  14.         ], dtype=torch.float32)
  15.         self.conv1.bias.data = torch.tensor([0.1, 0.2], dtype=torch.float32)
  16.         
  17.         # 手动初始化全连接层参数
  18.         self.fc = nn.Linear(2, 3, bias=True)
  19.         self.fc.weight.data = torch.tensor([[0.1, 0.2],
  20.                                            [0.3, 0.4],
  21.                                            [0.5, 0.6]], dtype=torch.float32)
  22.         self.fc.bias.data = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
  23.    
  24.     def forward(self, x):
  25.         # 卷积层
  26.         conv_out = self.conv1(x)
  27.         print(f"1. 卷积层输出: {conv_out.detach().squeeze().numpy().round(4)}")
  28.         
  29.         # ReLU激活
  30.         relu_out = F.relu(conv_out)
  31.         print(f"2. ReLU激活输出: {relu_out.detach().squeeze().numpy().round(4)}")
  32.         
  33.         # 展平特征图
  34.         flattened = relu_out.view(relu_out.size(0), -1)
  35.         
  36.         # 全连接层
  37.         fc_out = self.fc(flattened)
  38.         print(f"3. 全连接层输出: {fc_out.detach().squeeze().numpy().round(4)}")
  39.         
  40.         return fc_out

  41. # 创建模型
  42. model = SimpleCNN()

  43. # 输入数据 (batch_size=1, channels=1, height=2, width=2)
  44. x = torch.tensor([[[[1.0, 0.0],
  45.                   [0.0, 1.0]]]])

  46. # 真实标签 (类别索引为2)
  47. y = torch.tensor([1])

  48. for i in range(20):
  49. # 前向传播
  50.     print("================= 前向传播 =================!!!!!!!!!!!!!!!!!!!!")
  51.     output = model(x)

  52.     # 计算损失 (交叉熵损失)
  53.     criterion = nn.CrossEntropyLoss()
  54.     loss = criterion(output, y)

  55.     # 计算softmax概率
  56.     probs = F.softmax(output, dim=1)
  57.     print(f"4. Softmax概率: {probs.detach().squeeze().numpy().round(4)}")
  58.     print(f"5. 损失值: {loss.item():.4f}")

  59.     # 反向传播
  60.     print("\n================= 反向传播 =================")
  61.     model.zero_grad()
  62.     loss.backward()

  63.     # 打印梯度
  64.     print("6. 卷积层权重梯度:")
  65.     print(model.conv1.weight.grad.detach().squeeze().numpy().round(4))
  66.     print("卷积层偏置梯度:", model.conv1.bias.grad.detach().numpy().round(4))

  67.     print("\n7. 全连接层权重梯度:")
  68.     print(model.fc.weight.grad.detach().numpy().round(4))
  69.     print("全连接层偏置梯度:", model.fc.bias.grad.detach().numpy().round(4))

  70.     # 参数更新 (学习率0.1)
  71.     print("\n================= 参数更新 (学习率0.1) =================")
  72.     with torch.no_grad():
  73.         # 更新卷积层参数
  74.         model.conv1.weight -= 0.1 * model.conv1.weight.grad
  75.         model.conv1.bias -= 0.1 * model.conv1.bias.grad
  76.         
  77.         # 更新全连接层参数
  78.         model.fc.weight -= 0.1 * model.fc.weight.grad
  79.         model.fc.bias -= 0.1 * model.fc.bias.grad

  80.     print("8. 更新后的卷积层权重:")
  81.     print(model.conv1.weight.detach().squeeze().numpy().round(4))
  82.     print("更新后的卷积层偏置:", model.conv1.bias.detach().numpy().round(4))

  83.     print("\n9. 更新后的全连接层权重:")
  84.     print(model.fc.weight.detach().numpy().round(4))
  85.     print("更新后的全连接层偏置:", model.fc.bias.detach().numpy().round(4))

  86. print("================= 前向传播 =================")
  87. output = model(x)

复制代码
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver|手机版|小黑屋|DiscuzX

GMT+8, 2025-6-8 03:20 , Processed in 0.033679 second(s), 19 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表