模型参数加载

一般模型参数加载的方式:

model.load_state_dict(torch.load(path))

torch.save(model.state_dict(), path)

然而,nn.Parameter(torch.ones(10))这样使用nn.Parameter进行初始化的,
load_state_dict()就失效了,可以这样:

print("Model loaded from {}".format(self.args.checkpoint_dir))
checkpoint = torch.load(self.args.checkpoint_dir + f'checkpoint_epoch_best.pth', map_location=f'cuda:{self.args.gpu}')
# print("checkpoint:",checkpoint)
pl = checkpoint['state_dict']['ctx']

# self.model.load_state_dict(checkpoint['state_dict'], strict=False)不用这个了
self.model.prompt_learner.set_pl(pl)

当然此时模型类在定义时需要写一个能将模型参数通过外部数据进行赋值的函数,通过该函数将params传进去即可。

class PromptLearner(nn.Module):
    def __init__(self,llm,tokenizer,embed_tokens):
        super().__init__()
		print("Initializing a generic context")
        ctx_vectors = torch.ones(1,n_ctx, ctx_dim, dtype=dtype)
        self.ctx = nn.Parameter(ctx_vectors)  # 以上述向量初始为可优化参数 to be optimized
    def set_pl(self,pl):
        self.ctx = nn.Parameter(pl)

    def forward(self):
        ctx = self.ctx
        print("ccctx:",ctx,ctx.size())

torch.argmax()

next_tokens = torch.argmax(next_tokens_scores, dim=-1)#返回指定维度最大值的序号

也就是0维竖着看,1维横着看

import torch

x = torch.randn(2, 4)
print(x)
'''
tensor([[ 1.2864, -0.5955,  1.5042,  0.5398],
        [-1.2048,  0.5106, -2.0288,  1.4782]])
'''

# y0表示矩阵dim=0维度上(每一列)张量最大值的索引
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([0, 1, 0, 1])
'''

# y1表示矩阵dim=1维度上(每一行)张量最大值的索引
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([2, 3])
'''
^ :代表非

^A-Za-z: 代表非字母

[^A-Za-z]+ :可连续多个非字母的字符

.strip() :去掉首位的空格

.lower():把大写字母全部统一成小写

相关推荐

  1. 模型参数

    2024-03-24 00:22:03       18 阅读
  2. SVM 保存和模型参数

    2024-03-24 00:22:03       33 阅读
  3. python SVM 保存和模型参数

    2024-03-24 00:22:03       39 阅读
  4. pytorch保存和模型以及如何load部分参数

    2024-03-24 00:22:03       21 阅读
  5. PyTorch:模型方法详解

    2024-03-24 00:22:03       35 阅读
  6. Keras预训练模型

    2024-03-24 00:22:03       40 阅读
  7. anylabeling 模型后出错

    2024-03-24 00:22:03       48 阅读
  8. 【R3F】11.模型

    2024-03-24 00:22:03       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-03-24 00:22:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-24 00:22:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-24 00:22:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-24 00:22:03       18 阅读

热门阅读

  1. oracle添加用户

    2024-03-24 00:22:03       22 阅读
  2. 第四章 可变参数模板

    2024-03-24 00:22:03       19 阅读
  3. SQL运维_Unix下MySQL-5.5.11配置文件示例

    2024-03-24 00:22:03       20 阅读
  4. TensorFlow的研究应用与开发~深度学习

    2024-03-24 00:22:03       19 阅读
  5. 桥接模式简介

    2024-03-24 00:22:03       25 阅读
  6. MyBatis Plus笔记

    2024-03-24 00:22:03       16 阅读
  7. ns3-dev报错:fatal error: numbers: No such file or directory

    2024-03-24 00:22:03       18 阅读
  8. oracle表备份及还原

    2024-03-24 00:22:03       16 阅读