Compare commits

...

2 Commits

Author SHA1 Message Date
2cad997a27 update Unet 2024-04-15 13:46:41 +02:00
368e2f7dc2 update gitignore for merge 2024-04-15 13:42:25 +02:00
5 changed files with 183 additions and 180 deletions

5
.gitignore vendored
View File

@ -2,4 +2,7 @@
.DS_Store .DS_Store
./UNet/train_image/* ./UNet/train_image/*
./UNet/params/* ./UNet/params/*
./UNet/__pycache__/* ./UNet/__pycache__/*
data/
archive.zip
flowers/*

View File

@ -1,31 +1,31 @@
import os import os
from torch.utils.data import Dataset from torch.utils.data import Dataset
from utils import * from utils import *
from torchvision import transforms from torchvision import transforms
# Preprocessing pipeline shared by images and masks: a single ToTensor step.
transform = transforms.Compose([transforms.ToTensor()])
#use VOC2007 Dataset #use VOC2007 Dataset
class MyDataset(Dataset):
    """Segmentation dataset pairing mask files with their source images.

    ``path`` is expected to contain a ``SegmentationClass`` directory of
    ``.png`` masks and a ``JPEGImages`` directory holding, for every mask, a
    ``.jpg`` image with the same base name.
    """

    def __init__(self, path):
        """Index the mask files found under ``path/SegmentationClass``.

        Args:
            path: Root directory of the dataset.
        """
        self.path = path
        # Keep only mask files: stray entries (e.g. ``.DS_Store``) would
        # otherwise be handed to the image loader.  Sorting makes the
        # index -> sample mapping reproducible, because ``os.listdir``
        # returns entries in arbitrary order.
        self.name = sorted(
            entry
            for entry in os.listdir(os.path.join(path, 'SegmentationClass'))
            if entry.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed mask files."""
        return len(self.name)

    def _image_path(self, segment_name):
        """Return the JPEG path that belongs to the mask ``segment_name``."""
        # Swap only the extension; a plain ``replace('png', 'jpg')`` would
        # also rewrite any "png" occurring inside the base name.
        stem = os.path.splitext(segment_name)[0]
        return os.path.join(self.path, 'JPEGImages', stem + '.jpg')

    def __getitem__(self, index):
        """Load the ``index``-th sample.

        Returns:
            A ``(image, segment_image)`` pair, both passed through the
            module-level ``transform``.
        """
        segment_name = self.name[index]  # e.g. xx.png
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = self._image_path(segment_name)
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(image_path)
        return transform(image), transform(segment_image)
if __name__ == '__main__':
    import sys

    # Smoke test: load the first sample and print the tensor shapes.
    # The dataset root may be given as the first command-line argument;
    # without one the original default location is used.
    default_root = '/Users/hanzhangma/Document/DataSet/VOC2007'
    data = MyDataset(sys.argv[1] if len(sys.argv) > 1 else default_root)
    # Fetch the sample once instead of loading both files twice.
    first_image, first_segment = data[0]
    print(first_image.shape)  # shape of the input image tensor
    print(first_segment.shape)  # shape of the segmentation mask tensor

View File

@ -1,87 +1,87 @@
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torch import randn from torch import randn
import torch import torch
class Conv_Block(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm, Dropout2d and LeakyReLU.

    Both convolutions use stride 1 with reflect padding of 1, so the spatial
    size of the input is preserved; only the channel count changes.
    """

    def __init__(self, in_channel, out_channel, dropout=0.3):
        """Build the block.

        Args:
            in_channel: Number of channels of the input tensor.
            out_channel: Number of channels produced by both convolutions.
            dropout: Probability used by both ``Dropout2d`` layers
                (defaults to the previously hard-coded 0.3).
        """
        super().__init__()
        # Module order (and therefore the ``layer.N`` state-dict keys) is kept
        # unchanged so existing checkpoints still load.
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                      kernel_size=3, stride=1, padding=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(dropout),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                      kernel_size=3, stride=1, padding=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(dropout),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.layer(x)
class DownSample(nn.Module):
    """Halve the spatial resolution with a strided 3x3 convolution.

    The convolution keeps the channel count, uses stride 2 with reflect
    padding of 1, and is followed by BatchNorm and LeakyReLU.
    """

    def __init__(self, channel):
        """Build the layer.

        Args:
            channel: Number of input (and output) channels.
        """
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Return the downsampled version of ``x``."""
        return self.layer(x)
class UpSample(nn.Module):
    """Upsample a feature map and concatenate it with a skip connection.

    The input is resized with nearest-neighbour interpolation, its channel
    count is halved by a 1x1 convolution, and the result is concatenated with
    ``feature_map`` along the channel dimension.
    """

    def __init__(self, channel):
        """Build the layer.

        Args:
            channel: Number of channels of the tensor to upsample; the 1x1
                convolution reduces it to ``channel // 2``.
        """
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel // 2, 1, 1)
        )

    def forward(self, x, feature_map):
        """Upsample ``x`` to the size of ``feature_map`` and concatenate both.

        Args:
            x: Tensor coming from the deeper decoder stage.
            feature_map: Skip-connection tensor from the encoder.

        Returns:
            ``torch.cat((conv(upsampled x), feature_map), dim=1)``.
        """
        # Resize to the skip connection's spatial size rather than by a fixed
        # factor of 2: the result is the same when the sizes differ by exactly
        # 2x, and the concatenation no longer fails when a stride-2
        # downsampling rounded an odd size up.
        up = F.interpolate(x, size=feature_map.shape[2:], mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)
class UNet(nn.Module):
    """U-Net style encoder/decoder with four down- and four up-sampling stages.

    The encoder widens the channels 64 -> 128 -> 256 -> 512 -> 1024 while
    halving the resolution four times; the decoder mirrors it and merges the
    matching encoder output at every stage.  A final 3x3 convolution followed
    by a sigmoid produces the output map.
    """

    def __init__(self, in_channels=3, num_classes=3):
        """Build the network.

        Args:
            in_channels: Channels of the input image (default 3, the
                previously hard-coded value).
            num_classes: Channels of the output map (default 3, the
                previously hard-coded value).
        """
        super().__init__()
        # Encoder
        self.c1 = Conv_Block(in_channels, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        # Decoder: every UpSample halves the channels and concatenates the
        # skip connection, so the following Conv_Block sees the full width.
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # Output head; attribute names are kept so existing checkpoints load.
        self.out = nn.Conv2d(64, num_classes, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated outputs.

        Inputs whose height and width are multiples of 16 keep the encoder
        and decoder feature maps aligned through the four stride-2 stages.
        """
        # Encoder path; R1..R4 are kept for the skip connections.
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        # Decoder path, merging the encoder outputs from deepest to shallowest.
        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))
        return self.Th(self.out(O4))
if __name__ == '__main__':
    # Smoke test: feed a random batch of two 3x256x256 inputs through the
    # network and print the shape of what comes out.
    sample = randn(2, 3, 256, 256)
    model = UNet()
    prediction = model(sample)
    print(prediction.shape)

View File

@ -1,53 +1,53 @@
import torch import torch
from torch import optim from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from data import * from data import *
from net import * from net import *
from torchvision.utils import save_image from torchvision.utils import save_image
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint file that training resumes from and periodically overwrites.
weight_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/params/unet.pth'
# Root directory of the dataset.
data_path = r'/Users/hanzhangma/Document/DataSet/VOC2007'
# Directory receiving the preview images written during training.  The
# directory name is spelled "UNet" exactly as in weight_path; the former
# "Unet" spelling points somewhere else on case-sensitive filesystems.
save_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/train_image'
if __name__ == '__main__':
    # Imported locally so the script does not depend on ``os`` leaking in
    # through ``from data import *``.
    import os

    data_loader = DataLoader(MyDataset(data_path), batch_size=4, shuffle=True)
    net = UNet().to(device)

    # Resume from an existing checkpoint when there is one.
    if os.path.exists(weight_path):
        # map_location lets a checkpoint saved on a GPU be restored on a
        # machine that only has a CPU (and vice versa).
        net.load_state_dict(torch.load(weight_path, map_location=device))
        print('successful load weight!')
    else:
        print('Failed on load weight!')

    # Create the output locations up front; saving the checkpoint or the
    # preview images fails if the directories are missing.
    weight_dir = os.path.dirname(weight_path)
    if weight_dir:
        os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()

    epoch = 1
    # Trains until the process is interrupted; progress is persisted through
    # the periodic checkpoint below.
    while True:
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()  # apply the gradients to update the parameters

            if i % 5 == 0:
                print(f'{epoch} -- {i} -- train loss ===>> {train_loss.item()}')

            if i % 50 == 0:
                torch.save(net.state_dict(), weight_path)

                # Preview: input, target mask and prediction of the first
                # sample in the batch, stacked into one image grid.
                _image = image[0]
                _segment_image = segment_image[0]
                _out_image = out_image[0]
                img = torch.stack([_image, _segment_image, _out_image], dim=0)
                save_image(img, f'{save_path}/{i}.png')

        epoch += 1

View File

@ -1,10 +1,10 @@
from PIL import Image from PIL import Image
def keep_image_size_open(path, size=(256, 256)):
    """Open an image, pad it to a square and resize it to ``size``.

    The image is pasted into the top-left corner of a black RGB square whose
    side equals the longer edge of the image, so the aspect ratio of the
    content is kept when the square is resized.

    Args:
        path: Location of the image file to open.
        size: ``(width, height)`` of the returned image.

    Returns:
        A new RGB ``PIL.Image.Image`` of the requested size (resized with
        Pillow's default resampling filter).
    """
    # The context manager releases the underlying file as soon as the pixels
    # have been copied onto the canvas.
    with Image.open(path) as img:
        side = max(img.size)
        mask = Image.new('RGB', (side, side), (0, 0, 0))
        mask.paste(img, (0, 0))
    return mask.resize(size)