Discuz! Board

 找回密码
 立即注册
搜索
热搜: 活动 交友 discuz
查看: 4023|回复: 14

python卷积网络进行识别,手写字符识别

[复制链接]

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
发表于 2023-9-8 12:01:58 | 显示全部楼层 |阅读模式
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torch.utils.data import DataLoader, Dataset
  5. from PIL import Image
  6. import os
  7. import torch.nn.functional as F
  8. from torchvision import transforms
  9. # 设置训练超参数
  10. batch_size = 32#每一轮输入图像的张数。
  11. learning_rate = 0.001#学习率,每一次梯度下降的程度,学习率太大容易找不到最优价。
  12. num_epochs = 10#最大epoch数量,这里训练10轮,


  13. # 定义模型类
  14. class DigitNet(nn.Module):#从torch的Module模块继承并构建新的网络DigitNet。
  15.     def __init__(self):
  16.         super().__init__()#继承Module的初始化方法。
  17.         self.conv1 = nn.Conv2d(1, 10, 5)#卷积层1,3个参数分别为输入通道数,输出通道数,卷积核大小
  18.         self.conv2 = nn.Conv2d(10, 20, 3)#卷积层2,同上
  19.         self.fc1 = nn.Linear(20 * 10 * 10, 500)#这里的20是conv2输出的通道数,10是由于输入图像大小为28,经过conv1大小变为24,通过max_pool2d大小减半变为12,再通过conv2变为10,因此大小是10*10,输出500维的特征
  20.         self.fc2 = nn.Linear(500, 16)#通过500维的特征输出16个分类

  21.     def forward(self, x):#前向传递过程
  22.         input_size = x.size(0)
  23.         x = self.conv1(x)#卷积层1
  24.         x = F.relu(x)#激活函数,不改变大小
  25.         x = F.max_pool2d(x, 2, 2)#池化,大小减半
  26.         x = self.conv2(x)#卷积层2
  27.         x = F.relu(x)#激活函数
  28.         x = x.view(input_size, -1)#展平操作,用于构建fc层输入
  29.         x = self.fc1(x)#全连接层1
  30.         x = F.relu(x)#激活函数
  31.         x = self.fc2(x)#全连接层2
  32.         output = F.log_softmax(x, dim=1)#softmax分类
  33.         return output#返回输出类别


  34. # 定义自定义数据集类
  35. class CustomDataset(Dataset):
  36.     def __init__(self, root_dir, transform=None):
  37.         self.root_dir = root_dir  # 根目录路径
  38.         self.classes = os.listdir(root_dir)  # 获取根目录下的所有类别(子文件夹)名字
  39.         self.data = []  # 存储数据文件路径
  40.         self.targets = []  # 存储数据对应的标签
  41.         self.transform = transform  # 数据预处理的转换操作

  42.         for i, class_name in enumerate(self.classes):
  43.             class_dir = os.path.join(self.root_dir, class_name)  # 每个类别的文件夹路径
  44.             file_names = os.listdir(class_dir)  # 获取当前类别文件夹下的所有文件名
  45.             for file_name in file_names:
  46.                 file_path = os.path.join(class_dir, file_name)  # 每个文件的完整路径
  47.                 self.data.append(file_path)  # 将文件路径添加到data列表中
  48.                 self.targets.append(i % 16)  # 将类别的索引添加到targets列表中,取余是为了使标签在0-15之间循环

  49.     def __len__(self):
  50.         return len(self.data)  # 返回数据集的样本数量

  51.     def __getitem__(self, idx):
  52.         image_path = self.data[idx]  # 获取指定索引处的图像路径
  53.         image = Image.open(image_path).convert('L')  # 使用PIL库打开图像,并将其转换为灰度图像

  54.         if self.transform is not None:
  55.             image = self.transform(image)  # 对图像进行预处理转换

  56.         target = self.targets[idx]  # 获取指定索引处的标签
  57.         return image, target  # 返回图像和对应的标签


  58. # 设置数据预处理和转换
  59. data_transform = transforms.Compose([
  60.     transforms.Resize((28, 28)),  # 调整图像大小为 28x28
  61.     transforms.ToTensor(),  # 将图像转换为张量
  62.     transforms.Normalize((0.5,), (0.5,))  # 归一化处理
  63. ])

  64. # 创建数据加载器
  65. dataset = CustomDataset('D:/data', transform=data_transform)  # 创建自定义数据集实例,并应用数据预处理的转换操作
  66. dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # 创建数据加载器,指定批量大小和是否随机打乱数据

  67. # 创建模型实例
  68. model = DigitNet()  # 创建数字识别模型实例

  69. # 定义损失函数和优化器
  70. criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,用于多分类问题
  71. optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器,用于参数优化

  72. # 设置模型为训练模式
  73. model.train()

  74. # 开始训练
  75. for epoch in range(num_epochs):  # 遍历每个epoch
  76.     for images, labels in dataloader:  # 遍历每个batch的图像和标签
  77.         # 前向传播
  78.         outputs = model(images)  # 将图像输入模型,获取预测结果

  79.         # 计算损失
  80.         loss = criterion(outputs, labels)  # 计算预测结果与真实标签之间的损失

  81.         # 反向传播和优化
  82.         optimizer.zero_grad()  # 清空梯度
  83.         loss.backward()  # 反向传播计算梯度
  84.         optimizer.step()  # 更新模型参数

  85.     # 每个epoch结束后打印损失
  86.     print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

  87. # 保存模型
  88. torch.save(model.state_dict(), 'digit_model.pth')  # 保存模型参数到文件

复制代码
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2023-9-8 12:32:56 | 显示全部楼层
识别的代码
  1. def match_chars(arrX, model_path):
  2.     # 加载预训练模型
  3.     model = DigitNet()
  4.     model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  5.     model.eval()

  6.     # 设置图像预处理的转换
  7.     transform = transforms.Compose([
  8.         transforms.Resize((28, 28)),  # 调整图像大小为 28x28
  9.         transforms.ToTensor(),  # 将图像转换为张量
  10.         transforms.Normalize((0.5,), (0.5,))  # 归一化处理
  11.     ])

  12.     recognized_chars = []

  13.     # 循环遍历所有切割后的字符图像
  14.     for char_img in arrX:
  15.         char_img=padd(char_img)
  16.         plt.imshow(char_img, cmap='gray')
  17.         plt.show()
  18.         # 加载图像并进行预处理
  19.         image = Image.fromarray(char_img)
  20.         image = transform(image)  # 应用预处理转换
  21.         image = image.unsqueeze(0)  # 添加 batch 维度
  22.         # 进行字符识别
  23.         output = model(image)
  24.         probabilities = F.softmax(output, dim=1)
  25.         # 获取预测结果及对应的置信度
  26.         predicted_prob, predicted_label_idx = torch.max(probabilities, 1)
  27.         predicted_label = class_labels[predicted_label_idx.item()]
  28.         print(predicted_prob.item())
  29.         if predicted_prob.item()>0.9:
  30.             # 将预测结果添加到识别字符列表中
  31.             recognized_chars.append(str(predicted_label))
  32.         else:

  33.             recognized_chars.append(char_img)

  34.     return recognized_chars
复制代码
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2023-9-8 12:33:29 | 显示全部楼层
完整实例

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-14 16:48:09 | 显示全部楼层
pytorch】卷积操作原理解析与nn.Conv2d用法详解
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-14 16:48:18 | 显示全部楼层
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-14 17:54:58 | 显示全部楼层
import torch
import torch.nn as nn
# With square kernels and equal stride
m= nn.Conv2d(1, 1, 1, stride=1)
input = torch.randn(1,3, 3,4)
output = m(input)
print(input)
print(output)
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-15 09:21:25 | 显示全部楼层
import torch
import torch.nn as nn
# With square kernels and equal stride
m= nn.Conv2d(1, 1, 3, stride=1)
m.state_dict()['weight']=[[[[1, 2, 3],
[4, 5, 6],
[ 7,  8,  9]]]]
print(m.state_dict()['weight'])
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-15 09:22:50 | 显示全部楼层
conv_zeros.weight = torch.nn.Parameter(torch.ones(1,1,1,1))
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-15 09:24:09 | 显示全部楼层
PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和numpy是对应的,而且PyTorch的Tensor和numpy的ndarray可以相互转换,唯一不同的是PyTorch可以在GPU上运行,而numpy的ndarray只能在CPU上运行。
————————————————

                            版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
                        
原文链接:https://blog.csdn.net/vivi_cin/article/details/129052274
回复

使用道具 举报

399

主题

1251

帖子

4020

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
4020
 楼主| 发表于 2024-5-15 09:32:22 | 显示全部楼层
如何设置参数
import torch
import torch.nn as nn
# With square kernels and equal stride
m= nn.Conv2d(1, 1, 3, stride=1)
print(torch.ones(1,1,1,1))
m.weight = torch.nn.Parameter(torch.ones(1,1,1,1))
输入有4个参数,分别是batch channel  m n
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver|手机版|小黑屋|DiscuzX

GMT+8, 2025-6-8 07:41 , Processed in 0.041239 second(s), 19 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表