Competition link: https://tianchi.aliyun.com/competition/entrance/532086/introduction
Download the dataset from the competition page.
For semantic segmentation, many data augmentation operations are commonly used, such as geometric augmentations (horizontal flipping, translation, random cropping and scaling) and texture augmentations (motion blur, color jitter, etc.). The augmentation code is not included here, but a small sketch is given below.
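As a rough illustration only (assuming the albumentations library is available; the transform names and parameters below are illustrative and not the pipeline actually used for this entry), an image/mask augmentation pipeline might look like this:
import albumentations as A
# a hypothetical augmentation pipeline -- not the one used in this post;
# albumentations applies the same geometric transform to the image and its mask
train_aug = A.Compose([
    A.HorizontalFlip(p=0.5),                              # horizontal flip
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2,  # translation and scaling
                       rotate_limit=15, p=0.5),
    A.MotionBlur(blur_limit=7, p=0.2),                    # motion blur
    A.ColorJitter(p=0.2),                                  # color jitter (image only)
    # random cropping (e.g. A.RandomCrop / A.RandomResizedCrop) could be added here
])
# usage inside __getitem__, after the cv2 reads and color conversions:
# augmented = train_aug(image=ori_img, mask=lb_img)
# ori_img, lb_img = augmented['image'], augmented['mask']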
Import the required libraries
import torch
import os
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import cv2
import matplotlib.pyplot as plt  # used later to visualize predictions
The Dataset
transform = transforms.Compose([
    transforms.ToTensor()  # convert to a Tensor and scale pixel values to [0, 1]
])
# First, subclass Dataset to define how the data is read and processed
class MyDataset(Dataset):
    def __init__(self, path):
        # 'train' if the directory contains a mask/ folder, otherwise 'test'
        self.mode = ('train' if 'mask' in os.listdir(path) else 'test')
        self.path = path  # dataset root path
        dirlist = os.listdir(path + 'image/')  # image file names
        # only read .png images; adjust if other formats are present
        self.name = [n for n in dirlist if n[-3:] == 'png']

    def __len__(self):
        return len(self.name)

    def __getitem__(self, index):  # how a single sample is processed
        name = self.name[index]
        # read the original image and its label
        if self.mode == 'train':  # training mode
            ori_img = cv2.imread(self.path + 'image/' + name)  # original image
            lb_img = cv2.imread(self.path + 'mask/' + name)    # mask image
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)  # convert to RGB
            lb_img = cv2.cvtColor(lb_img, cv2.COLOR_BGR2GRAY)   # convert the mask to grayscale
            return transform(ori_img), transform(lb_img)
        if self.mode == 'test':  # test mode
            ori_img = cv2.imread(self.path + 'image/' + name)  # original image
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)  # convert to RGB
            return transform(ori_img)
# load the training set
train_path = 'train/'
traindata = MyDataset(train_path)
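A quick sanity check, not part of the original post (it assumes the images are 320x640, the size used later in this walkthrough):
img, mask = traindata[0]
print(len(traindata))          # number of training samples
print(img.shape, mask.shape)   # e.g. torch.Size([3, 320, 640]) and torch.Size([1, 320, 640])
print(mask.unique())           # ToTensor scales a 0/255 mask to 0./1., which the Dice loss below relies on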
Defining the model
The model used here is U-Net; we use the ready-made implementation from segmentation_models_pytorch.
model_path = 'models/'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# learning rate
lr = 3e-4
# weight decay (L2 regularization), passed to the optimizer below
weight_decay = 1e-4
# batch size
bs = 16
# number of training epochs
epochs = 20
import torchvision
import torch.nn as nn
import segmentation_models_pytorch as smp
model = smp.Unet(
    encoder_name="resnet152",     # encoder backbone
    encoder_weights='imagenet',   # initialize the encoder with ImageNet-pretrained weights
    in_channels=3,                # RGB input
    classes=1,                    # one output channel for binary segmentation (raw logits)
)
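As a quick shape check (a sketch, not in the original post; it assumes 320x640 inputs and relies on the fact that smp.Unet returns raw logits with no final activation):
with torch.no_grad():
    dummy = torch.randn(1, 3, 320, 640)  # H and W must be divisible by 32 for the default encoder depth
    print(model(dummy).shape)            # expected: torch.Size([1, 1, 320, 640])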
# preparation before training
from torch.utils.data import DataLoader
# move the model to the GPU (or the CPU)
model.to(device)
# Binary Cross Entropy could also be used as the loss for this binary segmentation problem
# BCEloss = nn.BCELoss()
# optimizer: Adam is used here; other optimizers can be substituted as needed
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# build a DataLoader from traindata
trainloader = DataLoader(traindata, batch_size=bs, shuffle=True, num_workers=1)
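Optionally (another sketch, not in the original post), pull one batch from the loader to confirm the shapes the model will receive, again assuming 320x640 images:
imgs, masks = next(iter(trainloader))
print(imgs.shape, masks.shape)  # expected: torch.Size([16, 3, 320, 640]) and torch.Size([16, 1, 320, 640])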
# Dice loss is chosen to match the competition metric; the implementation below is adapted from open-source code
def dice_loss(logits, target):
    smooth = 1.
    prob = torch.sigmoid(logits)
    batch = prob.size(0)
    prob = prob.view(batch, 1, -1)
    target = target.view(batch, 1, -1)
    intersection = torch.sum(prob * target, dim=2)
    denominator = torch.sum(prob, dim=2) + torch.sum(target, dim=2)
    dice = (2 * intersection + smooth) / (denominator + smooth)
    dice = torch.mean(dice)
    dice_loss = 1. - dice
    return dice_loss
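A quick sanity check of this loss (a sketch, not from the original post): logits that strongly agree with the target should give a loss near 0, while logits that strongly disagree should give a much larger value (the smoothing term keeps it below 1).
target = torch.tensor([[[1., 0., 1., 0.]]])             # shape (1, 1, 4)
good_logits = torch.tensor([[[10., -10., 10., -10.]]])  # sigmoid(logits) ~ target
bad_logits = -good_logits                               # sigmoid(logits) ~ 1 - target
print(dice_loss(good_logits, target))   # ~0
print(dice_loss(bad_logits, target))    # ~0.8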
loss_last = 99999
best_model_name = 'x'
# track the loss during training
for epoch in range(1, epochs + 1):
    for step, (inputs, labels) in tqdm(enumerate(trainloader), desc=f"Epoch {epoch}/{epochs}",
                                       ascii=True, total=len(trainloader)):
        # images and labels
        inputs, labels = inputs.to(device), labels.to(device)
        out = model(inputs)
        loss = dice_loss(out, labels)
        # backward pass
        optim.zero_grad()
        # back-propagate the gradients
        loss.backward()
        optim.step()
    # save a checkpoint when the loss is lower than the previous best
    # (note: this compares the loss of the last batch of the epoch)
    if loss < loss_last:
        loss_last = loss
        torch.save(model.state_dict(), model_path + 'model_epoch{}_loss{}.pth'.format(epoch, loss.item()))
        best_model_name = model_path + 'model_epoch{}_loss{}.pth'.format(epoch, loss.item())
    print(f"\nEpoch: {epoch}/{epochs}, DiceLoss: {loss.item()}")  # print the loss
Processing the test images and submitting for evaluation
Pick the checkpoint with the lowest loss for testing.
model.load_state_dict(torch.load('models/model_epoch7_lossxxxxxxxx.pth'))  # replace with the filename of your best checkpoint
model.eval()  # put BatchNorm/Dropout into evaluation mode before inference
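Alternatively (not in the original post), best_model_name already stores the path of the most recent checkpoint saved during training, so the same weights can be loaded without typing the filename manually:
model.load_state_dict(torch.load(best_model_name, map_location=device))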
# load the test set
test_path = 'test/'
testdata = MyDataset(test_path)
# check the model's predictions on a random test image
x = np.random.randint(0, len(testdata))
inputs = testdata[x].to(device)
with torch.no_grad():
    # model prediction (raw logits)
    t = model(inputs.view(1, 3, 320, 640))
plt.subplot(1, 2, 1)
plt.imshow(testdata[x].permute(1, 2, 0))
# binarize the prediction with a threshold
threshold = 0.5
t = torch.sigmoid(t)  # convert logits to probabilities so the 0.5 threshold is meaningful
t = torch.where(t >= threshold, torch.tensor(255, dtype=torch.float).to(device), t)
t = torch.where(t < threshold, torch.tensor(0, dtype=torch.float).to(device), t)
t = t.cpu().view(1, 320, 640)
plt.subplot(1, 2, 2)
plt.imshow(t.permute(1, 2, 0))
from torchvision.utils import save_image
from PIL import Image
img_save_path = 'infers/'
for i, inputs in tqdm(enumerate(testdata)):
    # original image
    inputs = inputs.reshape(1, 3, 320, 640).to(device)
    # model prediction (raw logits)
    with torch.no_grad():
        out = model(inputs)
    # post-process the output: sigmoid, then threshold to 0/255
    threshold = 0.5
    out = torch.sigmoid(out)
    out = torch.where(out >= threshold, torch.tensor(255, dtype=torch.float).to(device), out)
    out = torch.where(out < threshold, torch.tensor(0, dtype=torch.float).to(device), out)
    # save the image
    out = out.cpu().numpy().reshape(1, 320, 640)
    # note: the submission requires 1-bit images
    img = Image.fromarray(out[0].astype(np.uint8))
    img = img.convert('1')
    img.save(img_save_path + testdata.name[i])
Check the generated masks, then compress the images and submit.
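As a final convenience (a sketch, not from the original post; check the competition page for the exact archive layout it expects), the infers/ folder can be zipped with shutil:
import shutil
# creates infers.zip containing the contents of the infers/ directory
shutil.make_archive('infers', 'zip', 'infers')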