class CBAM_LSTM(nn.Module):
    """Sequence model: per-time-step CBAM, then an LSTM, then a linear head.

    Args:
        input_dim: Feature size of each time step; passed to ``CBAM`` and used
            as the LSTM input size.
        hidden_dim: LSTM hidden-state size (also the linear layer's input size).
        output_dim: Number of scores produced by the final linear layer.
        seq_length: Nominal sequence length. Stored on the instance for
            callers; ``forward`` iterates over the actual time dimension of
            its input rather than this value.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, seq_length):
        super().__init__()
        self.cbam = CBAM(input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        # Kept for backward compatibility; forward() does not depend on it, so
        # sequences shorter or longer than this are handled correctly.
        self.seq_length = seq_length

    def forward(self, x):
        """Return ``output_dim`` raw scores for each batch element.

        Args:
            x: Batch-first sequence tensor, expected to be
                (batch, seq_len, input_dim) to match the LSTM configuration.

        Returns:
            Tensor of shape (batch, output_dim) taken from the LSTM output at
            the final time step; no softmax is applied.

        Raises:
            ValueError: if ``x`` has no time steps.
        """
        if x.size(1) == 0:
            raise ValueError("CBAM_LSTM.forward: input sequence has no time steps")

        # CBAM stage: run CBAM on each time step separately, giving each step a
        # length-1 middle axis and removing it again afterwards.
        # NOTE(review): assumes CBAM maps (batch, 1, input_dim) to the same
        # shape — TODO confirm against CBAM's definition.
        attended = torch.stack(
            [self.cbam(step.unsqueeze(1)).squeeze(1) for step in x.unbind(dim=1)],
            dim=1,
        )

        # LSTM stage: keep only the output at the final time step.
        lstm_out, _ = self.lstm(attended)
        last_step = lstm_out[:, -1, :]

        # Fully connected layer: project the final LSTM output to output_dim scores.
        return self.fc(last_step)
# Model instantiation; output_dim=3 assumes the warning level has 3 classes.
model = CBAM_LSTM(
    input_dim=2,
    hidden_dim=16,
    output_dim=3,
    seq_length=SEQ_LENGTH,
)
# Cross-entropy loss over the model's raw (un-softmaxed) output scores.
criterion = nn.CrossEntropyLoss()
# Adam optimizer over every parameter registered on the model.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)