[关闭]
@chenyaofo 2022-10-24T04:39:55.000000Z 字数 2430 阅读 298

PyTorch 存储模型


如何保存符合平台规范的模型描述文件?

PyTorch中常用的保存模型方式一般仅仅保存模型的权重(state_dicts),而在本平台中,通过平台的“模型服务”中的“新建模型”上传的模型(称为“模型描述文件”)需要同时包含模型的权重和模型的计算图(即模型描述代码)。

这里将以一个自定义模型为例,展示如何保存符合平台规范的模型描述文件。

首先先安装dill库,安装方式是pip install dill,其可以在序列化对象实例的同时也序列化了对象描述代码。

保存模型描述文件的示例代码如下:

  1. import dill
  2. import torch
  3. import torch.nn as nn
  4. class MyModel(nn.Module):
  5. def __init__(self, num_classes=10):
  6. super().__init__()
  7. self.feature_extractor = nn.Sequential(
  8. nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3),
  9. nn.BatchNorm2d(num_features=64),
  10. nn.ReLU(),
  11. nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
  12. nn.BatchNorm2d(num_features=64),
  13. nn.ReLU(),
  14. )
  15. self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
  16. self.classifier = nn.Linear(in_features=64, out_features=num_classes)
  17. def forward(self, x):
  18. x = self.feature_extractor(x)
  19. x = self.global_pooling(x).squeeze(3).squeeze(2)
  20. x = self.classifier(x)
  21. return x
  22. model = MyModel()
  23. torch.save(model, "mymodel.pt", pickle_module=dill)
  24. load_model=torch.load("mymodel.pt", pickle_module=dill)
  25. state_dict = model.state_dict()
  26. load_state_dict = load_model.state_dict()
  27. for k,v in state_dict.items():
  28. print(f"In Layer {k}, the max difference is "
  29. f"{torch.max(torch.abs(state_dict[k]-load_state_dict[k])).item()}")

示例中值得注意的是保存模型描述文件时代码为torch.save(model, "mymodel.pt", pickle_module=dill)传入的第一个参数是模型实例本身,而不是模型的model.state_dict();此外,还需要传入自定义的pickle模块pickle_module=dill,这样才能序列化模型的计算图(即模型描述代码)。

运行上述代码,输出为:

  1. In Layer feature_extractor.0.weight, the max difference is 0.0
  2. In Layer feature_extractor.0.bias, the max difference is 0.0
  3. In Layer feature_extractor.1.weight, the max difference is 0.0
  4. In Layer feature_extractor.1.bias, the max difference is 0.0
  5. In Layer feature_extractor.1.running_mean, the max difference is 0.0
  6. In Layer feature_extractor.1.running_var, the max difference is 0.0
  7. In Layer feature_extractor.1.num_batches_tracked, the max difference is 0
  8. In Layer feature_extractor.3.weight, the max difference is 0.0
  9. In Layer feature_extractor.3.bias, the max difference is 0.0
  10. In Layer feature_extractor.4.weight, the max difference is 0.0
  11. In Layer feature_extractor.4.bias, the max difference is 0.0
  12. In Layer feature_extractor.4.running_mean, the max difference is 0.0
  13. In Layer feature_extractor.4.running_var, the max difference is 0.0
  14. In Layer feature_extractor.4.num_batches_tracked, the max difference is 0
  15. In Layer classifier.weight, the max difference is 0.0
  16. In Layer classifier.bias, the max difference is 0.0

输出结果证明了序列化之后的模型和序列化之前的模型在权重上无任何差异。

在保存好模型描述文件后(即上述例子中的mymodel.pt),即可通过平台的“模型服务”中的“新建模型”上传后使用了。

已知Bug:

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注