# 机器学习 **Repository Path**: i7606/machine-learning ## Basic Information - **Project Name**: 机器学习 - **Description**: 个人机器学习的仓库,包含案例和代码、笔记 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2026-01-11 - **Last Updated**: 2026-01-26 ## Categories & Tags **Categories**: Uncategorized **Tags**: Java, Swing ## README # 机器学习笔记 ## 一、介绍 这是个人机器学习的仓库,包含案例和代码、笔记、数据集等,通过学习AI入门教程视频:[【AI入门】【Pytorch入门】00.手写神经网络框架(计算图)](https://www.bilibili.com/video/BV15bBVBsEV4), 独自实现了一个仿PyTorch的自动微分系统,计算图生成器,还有一个简单的神经网络框架,并根据此神经网络框架实现了几个简单的学习案例,以帮助学习和理解机器学习原理。 ## 二、案例介绍 所有案例在example文件夹中可以找到 ### 1、计算图生成案例 ComputationalGraphTest.java 文件为计算图测试案例,该案例使用java模拟前向传播和反向传播,并且绘制计算图: ![计算图测试代码](./images/ComputationalGraphTest.png "计算图测试代码") 运行上述代码后会自动在浏览器打开生成的计算图,如下图所示: ![计算图测试代码](./images/CG.png "计算图测试代码") ### 2、神经网络框架使用案例 FrameworkTest.java 文件为仿pytorch的神经网络框架测试案例,可以使用此代码搭建一个感知机并设置激活函数和损失函数,可以查看其计算图,橙色节点为输出节点,绿色节点为标签节点 ![框架测试代码](./images/FrameworkTest.png "框架测试代码") 运行上述代码后会自动在浏览器打开生成的计算图,如下图所示: ![框架计算图测试代码](./images/CG2.png "框架计算图测试代码") ![框架计算图测试代码](./images/CG3.png "框架计算图测试代码") ### 3、细菌觅食案例 BacteriaForaging.java 文件是使用自定义神经网络框架以及自定义数据集编写的一个细菌觅食的案例,此案例模拟细菌觅食,在一个九宫格内,细菌在中心,周围8个格子随机分布食物,训练神经网络朝具有食物的格子方向移动。 计算图如下: ![框架计算图测试代码](./images/CG4.png "框架计算图测试代码") BacteriaForagingGUI.java 文件则是该案例的可视化体验页面,点击周围8个格子添加或清除食物,点击中间格子预测细菌的行动方向。 ![细菌觅食GUI](./images/GUI.png "细菌觅食GUI")
点击展开细菌觅食训练代码 ~~~java /** * 细菌觅食神经网络案例 */ public class BacteriaForaging { public static void main(String[] args) { // 1、创建网络 NeuronNet net = new NeuronNet(); net.setInputSize(9); net.setLearningRate(0.01); net.addLayer(5, ActivateEnum.SIGMOID) .addLayer(4, ActivateEnum.SOFTMAX, LossEnum.CrossEntropyLoss); // 2、加载数据 Dataset dataSet = new Dataset(); ArrayList trainDataSetItem = dataSet.getTrainDataSet(); // 3、训练,训练多次训练看loss是否下降 double[] loss = new double[4]; final int epoch = 300; for (int i = 0; i < epoch; i++) { for (int index = 0; index < trainDataSetItem.size(); index++) { DataItem dataItem = trainDataSetItem.get(index); loss[(int) dataItem.getLabel()] += net.training(dataItem.getValueData(), dataItem.getOneHotValueData()); } for (int j = 0; j < 4; j++) { loss[j] = loss[j] / 5; } System.out.println("Iter " + i + " Loss: " + Arrays.toString(loss)); } // 4、测试 int countAll = 0; int countCorrect = 0; ArrayList testDataSetItem = dataSet.getTestDataSet(); for (int i = 0; i < testDataSetItem.size(); i++) { DataItem dataItem = testDataSetItem.get(i); List predict = net.predict(dataItem.getValueData()); ++countAll; int pLabel = ExampleUtils.maxIndex(predict); if (pLabel == (int) dataItem.getLabel()) { ++countCorrect; } } // 5、输出训练和测试信息 double accuracyRate = ((double) countCorrect / (double) countAll) * 100; String format = String.format("共训练:%d轮\n测试数据量:%d\n正确数:%d\n正确率:%f%%", epoch, countAll, countCorrect, accuracyRate); System.out.println(format); // 6、生成计算图 List topo = GenreatorTopo.getTopo(); NNCGGenerator.generatorGraph(topo); } } ~~~
### 4、手写数字识别案例 Mnist.java 文件是使用自定义神经网络以及mnist-dataset-jpg数据集编写的手写数字识别案例,因为训练时间过长,所以此处没有计算图和GUI。 > 如果需要运行此案例,需要在README文件夹下解压mnist-dataset-jpg.7z文件,一个小时左右才跑完一轮,可能得三天时间才能跑完。 ## 三、软件架构 通过Value计算和存储节点,计算后,将计算的数值填充到新的Value的children,用于跟踪计算过程,整个构建过程类似于生成多叉树. NnGraphGenerator加载resources/template.html文件,然后将Topo数组填充到页面模板中, 待页面运行后,自动加载填充好的数据,以达到生成计算图的效果. 1. NnComputationalGraphTest.java: 测试文件 2. Value.java: 计算单元,包含id、label、data、grad、operator 3. NnGroup.java 仿元组,给Value用于存储子节点 4. NnGraphGenerator.java: 负责将节点生成计算图 5. GenreatorTopo.java: 负责生成和存储topo、visited 使用以下软件和依赖: java: 1. JDK17及以上版本 2. fastjson2 html/js: 1. cytoscape 2. graphlib 3. dagre 4. cytoscape-dagre