Transformer - Positional Encoding: Code Implementation

flyfish
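
The code below implements the sinusoidal positional encoding from "Attention Is All You Need":

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

In the code, div_term computes $1/10000^{2i/d_{\text{model}}}$ as $\exp\bigl(2i \cdot (-\ln 10000 / d_{\text{model}})\bigr)$, and the resulting (1, max_len, d_model) table is stored as a buffer so it is saved with the model but never trained.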

import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, added to the token embeddings."""

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the encodings for all positions up to max_len.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # div_term[i] = 1 / 10000^(2i / d_model), computed via exp/log.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcast over the batch
        # Registered as a buffer: saved with the model, but not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions; no gradient is needed.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# Embedding dimension is 64
d_model = 64
# Dropout rate is 0.1
dropout = 0.1
# Maximum sentence length
max_len = 60

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)

pe_result = pe(x)

print("pe_result:", pe_result)

Plotting

import numpy as np
import matplotlib.pyplot as plt

# Create a 15 x 5 figure
plt.figure(figsize=(15, 5))

# Dropout is 0 here so the plotted values are not randomly zeroed.
pe = PositionalEncoding(d_model, 0, max_len)

y = pe(torch.zeros(1, max_len, d_model))

# Look only at dimensions 3, 4, 5, and 6.
plt.plot(np.arange(max_len), y[0, :, 3:7].data.numpy())

plt.legend(["dim %d" % p for p in [3, 4, 5, 6]])
plt.show()

[Figure: sinusoidal positional-encoding values for dimensions 3, 4, 5, and 6 plotted against positions 0-59]
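
To see the pattern across all dimensions at once, the full encoding table could also be rendered as a heatmap (a sketch reusing the pe module and matplotlib import from above, not part of the original post):

# Sketch: heatmap of the full (max_len, d_model) positional-encoding table.
plt.figure(figsize=(10, 5))
plt.imshow(pe.pe[0, :max_len].data.numpy(), aspect="auto", cmap="viridis")
plt.xlabel("embedding dimension")
plt.ylabel("position")
plt.colorbar()
plt.show()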

Testing register_buffer
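
Briefly: register_buffer makes a tensor part of the module's state, so it appears in state_dict(), is saved and loaded with the model, and follows .to(device), but it is not an nn.Parameter and the optimizer never updates it. A plain tensor attribute gets none of this. A minimal sketch of the difference (the Demo class is mine, not from the test below):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.tmp = torch.randn(1, 3)                    # plain attribute: not tracked
        self.register_buffer("pe", torch.randn(1, 3))   # buffer: tracked, not trained

demo = Demo()
print("pe" in demo.state_dict())    # True  -> saved and loaded with the model
print("tmp" in demo.state_dict())   # False -> lost when saving
print(list(demo.parameters()))      # []    -> neither one is a trainable parameter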

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MLPNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)

        # Plain tensor attribute: NOT part of the module's state_dict.
        self.tmp = torch.randn(size=(1, 3))
        # Registered buffer: included in the state_dict, but not a trainable parameter.
        pe = torch.randn(size=(1, 3))
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))


net = MLPNet()
print(net.tmp)  # random at every run; never saved
print(net.pe)   # random at init; will be saved in the state_dict

print(torch.__version__)

root = "mydir/"

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Make sure the model lives on the same device as the data.
net.to(device)

train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)


criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

epochs = 1
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0

    net.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
        
        optimizer.zero_grad()
 
        out = net(images)
      
        loss = criterion(out, labels)
       
        train_loss += loss.item()
        train_acc += (out.max(1)[1] == labels).sum().item()
      
        loss.backward()
        optimizer.step()

    # Epoch averages, computed once after the batch loop.
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_train_acc = train_acc / len(train_loader.dataset)

    net.eval()
    with torch.no_grad():
        for (images, labels) in test_loader:
            images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
            out = net(images)
            loss = criterion(out, labels)
            val_loss += loss.item()
            acc = (out.max(1)[1] == labels).sum()
            val_acc += acc.item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
                   .format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    
    
    


dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
# state_dict() contains the trainable parameters and the registered buffer 'pe',
# but not the plain attribute 'tmp'.
torch.save(net.state_dict(), model_save_path)

model = MLPNet()
model.load_state_dict(torch.load(model_save_path))

print(model.tmp)  # freshly random: 'tmp' is not in the saved state_dict
print(model.pe)   # restored from the file
In a second, separate script, the model is rebuilt and the saved state_dict is loaded again:

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.2)

        # Plain tensor attribute: not saved or restored.
        self.tmp = torch.randn(size=(1, 3))
        # Registered buffer: restored from the saved state_dict below.
        pe = torch.randn(size=(1, 3))
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))


net = MLPNet()
print(net.tmp)
print(net.pe)


    

dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")

model = MLPNet()
model.load_state_dict(torch.load(model_save_path))


print(model.tmp)  # random again; not restored
print(model.pe)   # loaded from the saved state_dict

The pe value loaded from the saved model never changes: the registered buffer is part of the state_dict, so load_state_dict restores exactly what was written to disk, while the plain attribute tmp is re-randomized on every run because it is never saved.

tensor([[0.0566, 0.8944, 0.0873]])
tensor([[ 0.2529,  0.5227, -0.2610]])
tensor([[ 0.4632, -0.2602, -1.0032]])
tensor([[-0.3486,  0.8183, -1.3838]])
tensor([[ 0.7163,  0.5574, -0.0848]])
tensor([[-0.3415, -0.9013, -1.6136]])
tensor([[ 0.5490,  1.7691, -1.1375]])
tensor([[-0.3486,  0.8183, -1.3838]])
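
To see exactly what was written to disk, the keys of the saved state_dict can be inspected (a sketch reusing model_save_path from the scripts above):

# Sketch: list the keys stored in the checkpoint file.
saved = torch.load(model_save_path)
print(list(saved.keys()))
# fc1/fc2/fc3 weights and biases plus 'pe' -- 'tmp' is absent, because plain
# tensor attributes are never part of state_dict().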
