原文:Saving and Loading Models

作者:Matthew Inkawhich

介绍一系列关于 PyTorch 模型保存与加载的应用场景,主要包括三个核心函数:

[1] - torch.save

保存序列化的对象(Serialized object)到磁盘.

其中,应用了 Python 的 pickle 包,进行序列化,可适用于模型Models,张量Tensors,以及各种类型的字典对象的序列化保存.

[2] - torch.load

采用 Python 的 pickle 的 unpickling 函数,对磁盘 pickled 的对象文件进行反序列化(deserialize),加载到内存.

[3] - torch.nn.Module.load_state_dict

采用序列化的 state_dict 加载模型参数(字典).

1. state_dict 介绍

PyTorch中,torch.nn.Module 模型中的可学习参数(learnable parameters)(如,weights 和 biases),包含在模型参数(model parameters)里(根据 model.parameters() 进行访问.)

state_dict可以简单的理解为 Python 的字典对象,其将每一层映射到其参数张量.

注,只有包含待学习参数的网络层,如卷积层,线性连接层等,会在模型的 state_dict 中有元素值.

优化器对象(Optimizer object,torch.optim) 也有 state_dict,其包含了优化器的状态信息,以及所使用的超参数.

由于 state_dict 对象时 Python 字典的形式,因此,便于保存,更新,修改与恢复,有利于 PyTorch 模型和优化器的模块化.

例如,Training a classifier tutorial 中所使用的简单模型的 state_dict

# 模型定义
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ModelNet(nn.Module):
    def __init__(self):
        super(ModelNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 模型初始化
model = ModelNet()

# 优化器初始化
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, 
          "\t", 
          model.state_dict()[param_tensor].size()
         )

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight      torch.Size([6, 3, 5, 5])
conv1.bias      torch.Size([6])
conv2.weight      torch.Size([16, 6, 5, 5])
conv2.bias      torch.Size([16])
fc1.weight      torch.Size([120, 400])
fc1.bias      torch.Size([120])
fc2.weight      torch.Size([84, 120])
fc2.bias      torch.Size([84])
fc3.weight      torch.Size([10, 84])
fc3.bias      torch.Size([10])

Optimizer's state_dict:
param_groups      [{'weight_decay': 0, 
                   'dampening': 0, 
                   'params': [140448775121872, 140448775121728,
                              140448775121584, 140448775121440,
                              140448775121296, 140448775121152,
                              140448775121008, 140448775120864,
                              140448775120720, 140448775120576],
                   'nesterov': False, 
                   'momentum': 0.9, 
                   'lr': 0.001}]
state      {}

2. 模型保存与加载

2.1 保存/加载 state_dict (推荐)

# 模型保存
torch.save(model.state_dict(), PATH)

# 模型加载
model = ModelNet(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

当保存模型,用于推断时,只有训练的模型可学习参数是有必要进行保存的.

采用 torch.save() 函数保存模型的 state_dict,对于应用时,模型恢复具有最好的灵活性,因此推荐采用该方式进行模型保存.

PyTorch 通用模型保存格式为 .pt.pth 文件扩展名形式.

需要注意的时,在运行推断前,需要调用 model.eval() 函数,以将 dropout 层 和 batch normalization 层设置为评估模式(非训练模式).

load_state_dict() 函数的输入是字典形式,而不是对象保存的文件路径.

也就是说,在将保存的模型文件送入 load_state_dict() 函数前,必须将保存的 state_dict 进行反序列化.

例如,不能直接应用 model.load_state_dict(PATH),而是,load_state_dict(torch.load(PATH)).

2.2 保存/加载全部模型信息

# 保存
torch.save(model, PATH)

# 加载
model = ModelNet(*args, **kwargs) # 必须预先定义过模型.
model = torch.load(PATH)
model.eval()

这种方式是最直观的语法,包含最少的代码. 其会采用 Python 的pickle 模块保存全部的模型模块.

这种方式的缺点在于,序列化的数据受限于在模型保存时所采用的特定的类和准确的路径结构(specific classes and the exact directory structure). 其原因是,because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. 因此,再加载后经过许多重构后,或在其它项目中使用时,可能会被打乱.

2.3 保存/加载 CHECKPOINT 用于推断或恢复训练

# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# 加载
model = ModelNet(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# 或
model.train()

当保存模型断点checkpoint,用于推断或恢复训练时,必须保存模型state_dict 之外的其它信息. 优化器(optimizer) state_dict 的保存也很重要,其包含了模型训练时的缓存(buffers) 和参数. 其它保存信息,如断点 epoch,最后一次记录的训练loss,额外的 torch.nn.Embedding 层,等等.

模型保存时,为了保存多种组成的信息,需要将其组织为字典,并采用 torch.save() 序列化字典.

PyTorch 保存断点checkpoints 的格式为 .tar文件扩展名格式.

模型加载时,首先初始化模型和优化器;然后采用 torch.load() 加载字典. 这里,可以利用 python 字典查询,方便地访问所保存的信息项.

在恢复训练时,需要调用 model.train() 以确保所有网络层处于训练模式.

2.4 多个模型保存到一个文件

# 保存
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

# 加载
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

当需要将多个 torch.nn.Modules 组成的模型进行保存时,如 GAN,sequence-to-sequence 模型,模型集成(ensemble of models),可以采用与保存断点checkpoints模型相同的方法.

换句话说,将每个模型的 state_dict 以及对应的优化器保存为一个字典.

还可以保存任何其它有助于恢复训练的信息项,只需简单的添加其到字典中.

2.5 模型 WarmStarting

Warmstarting Model Using Parameters from a Different Model

采用其它模型的参数对模型进行“热身”.

# 保存
torch.save(modelA.state_dict(), PATH)

# 加载
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

在迁移学习和复杂模型训练的应用场景中,通常会遇到部分加载模型参数的情况.

利用部分训练的模型参数,哪怕只是使用很少一部分,也有助于训练过程的 warmstart;相比于完全从头开始训练,会加速模型训练的收敛速度.

无论是加载模型参数的部分 state_dict (丢弃了某些 keys ),还是模型参数 state_dict 的 keys 比要加载的模型 state_dict 更多,都可以设置 load_state_dict() 中的参数 strict=False,以忽略不匹配的 keys 元素.

如果需要将一个网络层的参数加载到其它网络层,但是 keys 不匹配,可以修改要加载模型的 state_dict 中的参数 keys 的名字,以匹配模型参数的 keys 名字.

2.6 跨设备保存/加载模型

Saving & Loading Model Across Devices

2.6.1 GPU 保存/CPU 加载

# 保存
torch.save(model.state_dict(), PATH)

# 加载
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

当在 CPU 上加载在 GPU 上训练保存的模型时,只需将 torch.device('cpu') 参数传递到 torch.load() 函数的 map_location 参数中. 此时,保存的参数张量会根据 map_location 参数自动的重新映射到 CPU 设备上.

2.6.2 GPU 保存/GPU 加载

# 保存
torch.save(model.state_dict(), PATH)

# 加载
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# 确保对输入模型的任何输入张量,调用 input = input.to(device) 处理

当在 GPU 上加载在 GPU 上训练保存的模型时,只需采用 model.to(torch.device('cuda'))将初始化的 model 转换为 CUDA 优化的模型.

此外,确保采用 .to(torch.device('cuda'))函数,对模型所有的输入张量进行操作.

注,调用 my_tensor.to(device),会返回 my_tensor 在 GPU 上的新副本,不会重写 my_tensor 的值. 因此,需要手工重写张量:

my_tensor = my_tensor.to(torch.device('cuda'))

2.6.3 CPU 保存/GPU 加载

# 保存
torch.save(model.state_dict(), PATH)

# 加载
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # 设定 GPU 设备ID
model.to(device)
# 确保对输入模型的任何输入张量,调用 input = input.to(device) 处理

当在 GPU 上加载在 CPU 上训练保存的模型时,只需将 cuda:device_id 参数传递到 torch.load() 函数的 map_location 参数中. 即可将模型加载到指定 GPU 设备上.

然后,调用 model.to(torch.device('cuda')) 以将模型参数张量转换为 CUDA 张量.

最后,采用 .to(torch.device('cuda'))函数,对CUDA优化的模型中所有的输入张量进行操作.

注,调用 my_tensor.to(device),会返回 my_tensor 在 GPU 上的新副本,不会重写 my_tensor 的值. 因此,需要手工重写张量:

my_tensor = my_tensor.to(torch.device('cuda'))

2.7 保存 torch.nn.DataParallel 模型

# 保存
torch.save(model.module.state_dict(), PATH)

# 加载
# 

torch.nn.DataParallel 是 PyTorch 的并行化 GPU 的模型封装.

通常采用 model.module.state_dict() 保存 DataParallel模型. 可以灵活的用于模型的加载.

Last modification:December 13th, 2018 at 02:12 pm