# pytorch-test **Repository Path**: chen934298133/pytorch-test ## Basic Information - **Project Name**: pytorch-test - **Description**: pytorch-test - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2026-02-01 - **Last Updated**: 2026-02-01 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # PyTorch MNIST 手写数字识别项目 > `https://gitee.com/kongfanhe/pytorch-tutorial.git` 这是一个使用 PyTorch 实现的简单 MNIST 手写数字识别项目。项目使用全连接神经网络对 0-9 的手写数字进行分类。 ## 项目结构 ``` pytorch-tutorial/ ├── main.py # 主程序入口 ├── models/ # 模型定义模块 │ ├── __init__.py │ └── model.py # 神经网络模型定义 ├── data/ # 数据处理模块 │ ├── __init__.py │ └── data_loader.py # 数据加载器 ├── training/ # 训练相关模块 │ ├── __init__.py │ ├── train.py # 训练逻辑 │ └── evaluate.py # 模型评估 ├── utils/ # 工具函数模块 │ ├── __init__.py │ └── model_utils.py # 模型保存/加载、配置等 ├── visualization/ # 可视化模块 │ ├── __init__.py │ └── visualize.py # 可视化预测结果 ├── test.py # 原始单文件版本(已废弃,可删除) ├── model.pth # 训练完成后保存的模型文件(训练后自动生成) └── README.md # 本文件 ``` ## 环境要求 - Python 3.6+ - PyTorch 1.0+ - torchvision - matplotlib - numpy ## 安装依赖 ```bash pip install torch torchvision matplotlib numpy ``` 或者使用 requirements.txt(如果存在): ```bash pip install -r requirements.txt ``` ## 如何开始训练 ### 运行主程序 直接运行 `main.py`: ```bash python main.py ``` 程序会自动执行以下步骤: 1. 下载 MNIST 数据集(首次运行需要网络连接) 2. 创建神经网络模型 3. 训练模型 2 个 epoch 4. 评估模型准确率 5. 保存训练好的模型到 `model.pth` 6. 可视化前 4 个测试样本的预测结果 ### 训练过程说明 训练过程中会输出以下信息: - **初始准确率**: 模型未训练前的准确率(通常约 10%,相当于随机猜测) - **Epoch 0 准确率**: 第一个训练周期后的准确率 - **Epoch 1 准确率**: 第二个训练周期后的准确率 - **模型已保存到: model.pth**: 训练完成后的保存提示 ### 自定义训练参数 如果需要修改训练参数,可以编辑以下文件: - **训练轮数**: 修改 `main.py` 中 `train()` 函数的 `num_epochs=2` 参数 - **学习率**: 修改 `main.py` 中 `train()` 函数的 `lr=0.001` 参数 - **批次大小**: 修改 `data/data_loader.py` 中 `get_data_loader()` 函数的 `batch_size=15` 参数 - **优化器**: 可以修改 `training/train.py` 中的优化器类型 ## 训练完的模型在哪里? 训练完成后,模型会自动保存到项目根目录下的 `model.pth` 文件中。 **模型文件位置**: `./model.pth`(项目根目录) 该文件包含了训练好的神经网络的所有权重参数,可以用于后续的推理和部署。 ## 如何加载训练完的模型 ### 方法一:使用项目提供的加载函数 在代码中使用 `load_model` 函数加载模型: ```python from models import Net from utils import load_model # 创建模型实例(必须与训练时的模型结构相同) net = Net() # 加载训练好的权重 net = load_model(net, filepath="model.pth") # 现在可以使用模型进行预测 # net.eval() # 设置为评估模式 # with torch.no_grad(): # output = net(input_tensor) ``` ### 方法二:直接使用 PyTorch 的加载方法 ```python import torch from models import Net # 创建模型实例 net = Net() # 加载模型权重 net.load_state_dict(torch.load("model.pth")) # 设置为评估模式 net.eval() # 使用模型进行预测 # with torch.no_grad(): # output = net(input_tensor) ``` ### 完整示例:加载模型并进行预测 创建一个新文件 `predict.py` 来演示如何加载和使用模型: ```python import torch from torchvision import transforms from torchvision.datasets import MNIST from models import Net from utils import load_model # 加载模型 net = Net() net = load_model(net, "model.pth") # 加载测试数据 transform = transforms.Compose([transforms.ToTensor()]) test_dataset = MNIST("", train=False, transform=transform, download=True) test_image, true_label = test_dataset[0] # 进行预测 net.eval() with torch.no_grad(): # 将图像展平并添加批次维度 input_tensor = test_image.view(1, -1) output = net(input_tensor) predicted = torch.argmax(output, dim=1) print(f"真实标签: {true_label}") print(f"预测结果: {predicted.item()}") ``` ## 模型架构 本项目使用的神经网络结构如下: - **输入层**: 784 个神经元(28×28 像素的图像展平) - **隐藏层1**: 64 个神经元,ReLU 激活函数 - **隐藏层2**: 64 个神经元,ReLU 激活函数 - **隐藏层3**: 64 个神经元,ReLU 激活函数 - **输出层**: 10 个神经元(对应 0-9 十个数字类别),LogSoftmax 激活函数 ## 模块说明 - **main.py**: 主程序入口,协调各个模块完成训练和测试流程 - **models/**: 模型定义模块 - `model.py`: 定义神经网络模型结构(Net类) - **data/**: 数据处理模块 - `data_loader.py`: 处理数据加载,提供 `get_data_loader()` 函数 - **training/**: 训练相关模块 - `train.py`: 包含训练逻辑,提供 `train()` 和 `train_epoch()` 函数 - `evaluate.py`: 模型评估功能,提供 `evaluate()` 函数 - **visualization/**: 可视化模块 - `visualize.py`: 可视化预测结果,提供 `visualize_predictions()` 函数 - **utils/**: 工具函数模块 - `model_utils.py`: 工具函数,包含模型保存/加载和配置设置 ## 注意事项 1. **首次运行**: 程序会自动下载 MNIST 数据集,请确保网络连接正常 2. **数据集位置**: MNIST 数据集会下载到当前目录,占用约 60MB 空间 3. **模型文件**: `model.pth` 文件包含训练好的权重,可以备份保存 4. **GPU 支持**: 如果安装了 CUDA 版本的 PyTorch,代码会自动使用 GPU 加速(需要手动修改代码添加 `.to(device)`) 5. **SSL证书**: 代码已处理SSL证书验证问题,支持在macOS等系统上下载数据集 6. **中文字体**: 已配置matplotlib支持中文显示,避免字体缺失警告 ## 常见问题 **Q: 如何提高模型准确率?** A: 可以尝试增加训练轮数(epoch)、调整学习率、增加隐藏层神经元数量、使用更复杂的网络结构等。 **Q: 模型文件可以删除吗?** A: 可以,但删除后需要重新训练才能使用。建议保留 `model.pth` 文件以便后续使用。 **Q: 如何在不同机器上使用训练好的模型?** A: 将 `model.pth` 文件复制到目标机器,确保 PyTorch 版本兼容,然后使用加载函数加载即可。 **Q: 训练时间需要多久?** A: 在普通 CPU 上,2 个 epoch 的训练大约需要几分钟。使用 GPU 会显著加快训练速度。