https://huggingface.co/docs/diffusers/using-diffusers/using_safetensors

.ckpt 和 .safetensors

[1] - .ckpt 文件是用 pickle 序列化的,可能包含恶意代码,不信任的模型来源,加载 .ckpt 可能存在安全问题.

https://docs.python.org/zh-cn/3/library/pickle.html

构建恶意的 pickle 数据来 在解封时执行任意代码 是可能的。 绝对不要对不信任来源的数据和可能被篡改过的数据进行解封。

处理不信任数据时,更安全的序列化格式如 json 可能更为适合。

[2] - .safetensors 文件是用 numpy 保存的,只包含张量数据,没有任何代码,加载 .safetensors 更安全和快速.

为了将 .ckpt 文件转换为 .safetensors 文件,需要先加载 .ckpt 中的数据,然后用 numpy 保存为 .safetensors.

safetensors

doc: https://huggingface.co/docs/safetensors/index

模型加载

import torch
import safetensors.torch
import os


def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(torch.load(
            ckpt_path, map_location=torch.device(location)))
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict


def get_state_dict(d):
    return d.get('state_dict', d)
Last modification:May 11th, 2023 at 11:42 am