用cityscapes fine tune yolov8-seg

cityscapes数据集预处理

import os
import random
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


def get_subfolders_with_path(folder_path):
	"""Return the sorted full paths of the directories directly inside *folder_path*.

	Only immediate children are considered; regular files are skipped.
	"""
	candidates = (os.path.join(folder_path, entry) for entry in os.listdir(folder_path))
	return sorted(path for path in candidates if os.path.isdir(path))


def get_files_in_folder(folder_path):
	"""Return the sorted full paths of the regular files directly inside *folder_path*.

	Only immediate children are considered; sub-directories are skipped.
	"""
	file_paths = []
	for entry in os.listdir(folder_path):
		full_path = os.path.join(folder_path, entry)
		if os.path.isfile(full_path):
			file_paths.append(full_path)
	return sorted(file_paths)


class CustomDataset(Dataset):
	"""Cityscapes training set yielding ``(image, label)`` pairs.

	Image paths are collected from every sub-folder of ``folder_path``. For each
	image the matching ``*_gtFine_labelIds.png`` mask is located by replacing
	the ``leftImg8bit`` parts of the image path with their ``gtFine``
	counterparts.

	Args:
		folder_path: Directory holding one sub-folder of images per city.
		input_size: ``(width, height)`` that image and mask are resized to.
	"""

	def __init__(self, folder_path='/home/Downloads/cityspaces/leftImg8bit/train/', input_size=(512, 512)):
		self.input_width, self.input_height = input_size

		# Gather every file of every city sub-folder, in sorted order.
		self.all_images_path = []
		for folder_i in get_subfolders_with_path(folder_path):
			self.all_images_path.extend(get_files_in_folder(folder_i))

	def __len__(self):
		"""Return the number of images found under the dataset root."""
		return len(self.all_images_path)

	@staticmethod
	def _label_path_for(image_path):
		"""Map a ``*leftImg8bit.<ext>`` image path to its ``*gtFine_labelIds.png`` mask path.

		Raises:
			ValueError: If the file name does not end with ``leftImg8bit``
				(before the extension), so no mask name can be derived.
		"""
		directory, filename = os.path.split(image_path)
		stem, _ext = os.path.splitext(filename)
		tag = "leftImg8bit"
		if not stem.endswith(tag):
			raise ValueError(f"Not a leftImg8bit image, cannot derive label path: {image_path}")
		label_name = stem[:-len(tag)] + "gtFine_labelIds.png"
		# Only the first occurrence in the directory part is swapped, i.e. the
		# dataset-root folder ".../leftImg8bit/..." becomes ".../gtFine/...".
		return os.path.join(directory.replace(tag, "gtFine", 1), label_name)

	def __getitem__(self, item):
		"""Load, resize and normalize one sample.

		Returns:
			A tuple ``(image, label)``: ``image`` is a float NumPy array in
			height x width x channel (RGB) layout, standardized per channel;
			``label`` is a ``torch`` tensor holding the resized label-id mask.

		Raises:
			ValueError: If the image file name is not a ``leftImg8bit`` file.
			FileNotFoundError: If the image or its mask cannot be read.
		"""
		name_i = self.all_images_path[item]
		label_path = self._label_path_for(name_i)

		# cv2.imread reports failure by returning None rather than raising.
		image = cv2.imread(name_i)
		if image is None:
			raise FileNotFoundError(f"Could not read image: {name_i}")
		label = cv2.imread(label_path, 0)  # flag 0: load as single-channel grayscale
		if label is None:
			raise FileNotFoundError(f"Could not read label mask: {label_path}")

		image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
		target_size = (self.input_width, self.input_height)
		# NOTE(review): nearest-neighbour is kept for the image to preserve the
		# existing preprocessing; a smoothing interpolation may be preferable.
		image = cv2.resize(image, target_size, interpolation=cv2.INTER_NEAREST)

		# Per-channel mean/std scaled to the 0-255 pixel range.
		mean = np.array([0.485, 0.456, 0.406]) * 255
		std = np.array([0.229, 0.224, 0.225]) * 255
		image = (image - mean) / std

		# labelIds span 0-33 (34 ids) while the yolov8-seg output used for
		# training has 32 channels (0-31), so larger ids are clamped to 31.
		label[label > 31] = 31
		# NOTE(review): id 1 is folded into id 7; the reason is not stated in
		# this file - confirm against the Cityscapes labelIds table.
		label[label == 1] = 7
		# Nearest-neighbour keeps the mask values valid class ids.
		label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)

		label = torch.from_numpy(label)

		return image, label
	

def get_dataloader(batch_size=8, shuffle=False, num_workers=0):
	"""Build the DataLoader over the Cityscapes training set.

	The defaults reproduce the previous hard-coded configuration
	(batch size 8, no shuffling, loading in the main process).

	Args:
		batch_size: Number of samples per batch.
		shuffle: Whether to reshuffle the samples every epoch.
		num_workers: Number of loader worker processes (0 = main process).

	Returns:
		A ``DataLoader`` over ``CustomDataset``; the last, possibly smaller,
		batch is kept (``drop_last=False``).
	"""
	train_set = CustomDataset()
	return DataLoader(
		train_set,
		batch_size=batch_size,
		shuffle=shuffle,
		num_workers=num_workers,
		drop_last=False,
	)


if __name__ == "__main__":
	# Smoke test: walk through the loader once and print every batch index.
	train_loader = get_dataloader()
	for step, (_images, _masks) in enumerate(train_loader):
		print(step)

加载yolov8-seg模型

# Load the pretrained YOLOv8n-seg model and predict an image with it (the train() call below is left commented out)
from ultralytics import YOLO
import matplotlib.pyplot as plt


# Pretrained YOLOv8n segmentation weights, loaded through the high-level API.
model = YOLO('yolov8n-seg.pt')

# Training through this API would look like the line below (left disabled):
# results = model.train(data='coco128-seg.yaml', epochs=100, imgsz=640)

# Run prediction on a single image.
sample_image = "/home/robotics/dino/img/IMAGE0000016.jpg"
output = model(sample_image)

这样得到的output不带梯度,无法用于训练。想要训练,要么调用上面注释掉的train方法(需要提前按照coco格式准备好数据集);如果不想制作coco格式的数据集,可以像下面这样直接调用底层的yoloSeg.model,拿到模型带梯度的输出:

from ultralytics import YOLO
import torch
import numpy as np
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt


# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a model
yoloSeg = YOLO('yolov8x-seg.yaml').load('yolov8x-seg.pt')  # build from YAML and transfer weights

name = "/home/dino/img/student_building/b1.jpg"
img0 = cv2.imread(name)
if img0 is None:
	# cv2.imread reports failure by returning None rather than raising.
	raise FileNotFoundError(f"Could not read image: {name}")
img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
img0 = cv2.resize(img0, (512, 512))

# Per-channel mean/std scaled to the 0-255 pixel range.
mean = np.array([0.485, 0.456, 0.406]) * 255
std = np.array([0.229, 0.224, 0.225]) * 255
# Standardize the image
img = (img0 - mean) / std

# Add a batch axis, cast to float32, then move channels first: (1, 3, 512, 512).
img = torch.from_numpy(img).unsqueeze(0).to(torch.float32).to(device)
img = img.permute(0, 3, 1, 2)

yoloSeg.model = yoloSeg.model.to(device)

# Calling the underlying nn.Module directly (instead of the YOLO wrapper)
# keeps the output attached to the autograd graph.
output = yoloSeg.model(img)
# NOTE(review): index 2 is assumed to be the per-pixel map of the
# segmentation head as returned in training mode - confirm against the
# installed ultralytics version.
result = output[2]

# Upsample the map back to the input resolution.
result = F.interpolate(result, size=(512, 512), mode='nearest')

# Channel with the highest response per pixel -> (1, 512, 512) index map,
# then drop the batch axis for plotting.
max_indices = torch.argmax(result, dim=1)
result = max_indices.cpu().numpy().astype(np.uint8).squeeze()

plt.subplot(121)
plt.imshow(img0)
plt.subplot(122)
plt.imshow(result, cmap='viridis')
plt.title('output')
plt.xticks([])
plt.yticks([])
plt.show()

print("done")

上面的 yoloSeg.model(img) 调用的是 ultralytics/nn/tasks.py 中 BaseModel 的 forward 方法。

对模型进行训练

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from prepare_data import get_dataloader
from ultralytics import YOLO


if __name__ == "__main__":
	# Local import: this script's import header does not bring in `os`.
	import os

	# Use the GPU when one is available, otherwise fall back to the CPU.
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	# torch.save does not create missing directories, so make sure it exists.
	checkpoint_path = "./my_checkpoints/my_train_temp.pth"
	os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

	train_loader = get_dataloader()
	if len(train_loader) == 0:
		# Fail early instead of dividing by zero when averaging the epoch loss.
		raise RuntimeError("The training DataLoader is empty; check the dataset path.")

	criterion = nn.CrossEntropyLoss()

	yoloSeg = YOLO('yolov8x-seg.yaml').load('yolov8x-seg.pt')  # build from YAML and transfer weights
	model = yoloSeg.model.to(device)
	# NOTE(review): the loop below reads element [2] of the forward output,
	# which is assumed to exist only in training mode - set it explicitly.
	model.train()

	optimizer = optim.AdamW(model.parameters(), lr=1e-5)

	writer = SummaryWriter('./log')  # directory the TensorBoard logs are written to

	num_epochs = 300
	try:
		for epoch in range(1, num_epochs+1):
			epoch_loss = 0.0
			for batch_idx, (data, target) in enumerate(train_loader):
				# Batches arrive channel-last; move channels to axis 1 for the model.
				data = data.to(device).float()
				data = data.permute(0, 3, 1, 2)
				# CrossEntropyLoss expects integer (int64) class indices.
				target = target.to(device).to(torch.int64)

				output = model(data)[2]
				# Upsample to the 512x512 size the targets are compared at.
				output = F.interpolate(output, size=(512, 512), mode='nearest')

				loss = criterion(output, target)
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()

				epoch_loss += loss.item()

				if batch_idx % 10 == 0:
					print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
						  .format(epoch, num_epochs, batch_idx+1, len(train_loader), loss.item()))

			avg_epoch_loss = epoch_loss / len(train_loader)
			writer.add_scalar('Training Loss', avg_epoch_loss, epoch)

			# Overwrite the same checkpoint file every 10 epochs.
			if epoch % 10 == 0:
				torch.save(model.state_dict(), checkpoint_path)
				print("Model weights saved.")
	finally:
		# Flush and close the event file even if training is interrupted.
		writer.close()

相关推荐

  1. cityscapes fine tune yolov8-seg

    2024-01-28 20:46:04       56 阅读
  2. SVG 字体 – SVG样式(17)

    2024-01-28 20:46:04       51 阅读
  3. sed简说

    2024-01-28 20:46:04       52 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-28 20:46:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-28 20:46:04       101 阅读
  3. 在Django里面运行非项目文件

    2024-01-28 20:46:04       82 阅读
  4. Python语言-面向对象

    2024-01-28 20:46:04       91 阅读

热门阅读

  1. 《动手学深度学习(PyTorch版)》笔记4.9

    2024-01-28 20:46:04       36 阅读
  2. kingbase常用SQL总结之使用率

    2024-01-28 20:46:04       55 阅读
  3. 代码随想录算法训练营29期Day31|LeetCode 455,376,53

    2024-01-28 20:46:04       62 阅读
  4. 【 C++私房菜】模板的入门与进阶

    2024-01-28 20:46:04       45 阅读
  5. FIND_IN_SET的使用:mysql表数据多角色、多用户查询

    2024-01-28 20:46:04       60 阅读