导入必要的库
1 | import torch |
定义一个函数,用于在一定布局下展示图像及其对应的标签
1 | def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): |
定义一个函数,用于获取数据加载时使用的工作线程数
1 | def get_dataloader_workers(): |
定义一个函数,用于加载 Fashion-MNIST 数据集
1 | def load_data_fashion_mnist(batch_size, resize=None): |
定义模型
1 | # 网络的输入和输出维度 |
定义交叉熵损失函数
1 | def cross_entropy(y_hat, y): |
模型评价指标计算
1 | def accuracy(y_hat, y): |
定义训练一个 epoch 的函数
1 | def train_epoch(net, train_iter, loss, updater): |
定义训练函数
1 | def train(net, train_iter, test_iter, loss, num_epochs, updater): |
通过以上步骤,我们成功实现了使用 PyTorch 进行 Softmax 回归模型的训练和测试。 Softmax 回归是深度学习中的重要组成部分,能够有效地处理分类问题。
softmax的高级API实现
1 | import torch # 导入PyTorch库 |