for notes

commit 7a071fa658 (parent dabc2495a9)

.gitignore (vendored): 8 changed lines
@@ -1,8 +1,10 @@
 ./flowers/*
 .DS_Store
-./UNet/train_image/*
-./UNet/params/*
-./UNet/__pycache__/*
+UNet/train_image/*
+UNet/params/*
+UNet/__pycache__/*
+UNet/test_image
 data/
 archive.zip
 flowers/*
+UNet/result/result.jpg
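Note on the .gitignore hunk above: in gitignore syntax, a pattern containing a slash is matched relative to the directory of the .gitignore file itself and is written without a leading "./", so the old entries such as "./UNet/params/*" matched nothing, while the rewritten "UNet/params/*" works. The same issue explains why a working "flowers/*" entry sits alongside the older, inert "./flowers/*". A quick check is "git check-ignore -v UNet/params/unet.pth", which reports the pattern, if any, that ignores a given path.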
@@ -7,9 +7,9 @@ from net import *
 from torchvision.utils import save_image
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-weight_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/UNet/params/unet.pth'
-data_path = r'/Users/hanzhangma/Document/DataSet/VOC2007'
-save_path = r'/Users/hanzhangma/Nextcloud/mhz/Study/SS24/MasterThesis/Unet/train_image'
+weight_path = r'D:\\MasterThesis\\UNet\\params\\unet.pth'
+data_path = r'D:\\MasterThesis\\data\\VOCdevkit\\VOC2007'
+save_path = r'D:\\MasterThesis\\UNet\\train_image'
 
 if __name__ == '__main__':
     data_loader = DataLoader(MyDataset(data_path), batch_size= 4, shuffle=True)
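This hunk trades one machine-specific absolute path (macOS) for another (Windows). A portable alternative is to derive all three paths from a single root; the sketch below assumes a MASTER_THESIS_ROOT environment variable, which is an illustrative name, not part of this commit:

    import os
    from pathlib import Path

    # Project root from an environment variable, falling back to the current
    # directory; the variable name is assumed for illustration.
    ROOT = Path(os.environ.get('MASTER_THESIS_ROOT', '.'))

    weight_path = ROOT / 'UNet' / 'params' / 'unet.pth'
    data_path = ROOT / 'data' / 'VOCdevkit' / 'VOC2007'
    save_path = ROOT / 'UNet' / 'train_image'

With this, switching machines means setting one environment variable instead of committing new hard-coded paths.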
@@ -665,6 +665,7 @@
 "\n",
 "        num_resolutions = len(base_channels_multiples)\n",
 "\n",
+"        # encoder blocks = resnetblock * 3 + \n",
 "        self.encoder_blocks = nn.ModuleList()\n",
 "        curr_channels = [base_channels]\n",
 "        in_channels = base_channels\n",
@@ -799,6 +800,7 @@
 "        self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1-self.alpha_cumulative)\n",
 "\n",
 "    def get_betas(self):\n",
+"        \"\"\"Linear schedule, proposed in the original DDPM paper\"\"\"\n",
 "        scale = 1000 / self.num_diffusion_timesteps\n",
 "        beta_start = scale * 1e-4\n",
 "        beta_end = scale * 0.02\n",
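For reference, the endpoints in this method (1e-4 and 0.02, rescaled by 1000/T) are the linear-schedule defaults from the DDPM paper (Ho et al., 2020). A minimal standalone sketch of such a schedule, assuming the betas are produced with torch.linspace (the dtype choice is an assumption):

    import torch

    def get_betas(num_diffusion_timesteps: int = 1000) -> torch.Tensor:
        """Linear beta schedule from the DDPM paper, rescaled so the
        endpoint values stay comparable when the step count changes."""
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 1e-4
        beta_end = scale * 0.02
        # One evenly spaced beta per timestep.
        return torch.linspace(beta_start, beta_end, num_diffusion_timesteps,
                              dtype=torch.float32)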
@@ -896,66 +898,6 @@
 "## Training"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 99,
-"metadata": {},
-"outputs": [],
-"source": [
-"@dataclass\n",
-"class ModelConfig:\n",
-"    BASE_CH = 64 # 64, 128, 256, 256\n",
-"    BASE_CH_MULT = (1, 2, 4, 4) # 32, 16, 8, 8\n",
-"    APPLY_ATTENTION = (False, True, True, False)\n",
-"    DROPOUT_RATE = 0.1\n",
-"    TIME_EMB_MULT = 4 # 128"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 100,
-"metadata": {},
-"outputs": [],
-"source": [
-"model = UNet(\n",
-"    input_channels = TrainingConfig.IMG_SHAPE[0],\n",
-"    output_channels = TrainingConfig.IMG_SHAPE[0],\n",
-"    base_channels = ModelConfig.BASE_CH,\n",
-"    base_channels_multiples = ModelConfig.BASE_CH_MULT,\n",
-"    apply_attention = ModelConfig.APPLY_ATTENTION,\n",
-"    dropout_rate = ModelConfig.DROPOUT_RATE,\n",
-"    time_multiple = ModelConfig.TIME_EMB_MULT,\n",
-")\n",
-"model.to(BaseConfig.DEVICE)\n",
-"\n",
-"optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig.LR)\n",
-"\n",
-"dataloader = get_dataloader(\n",
-"    dataset_name = BaseConfig.DATASET,\n",
-"    batch_size = TrainingConfig.BATCH_SIZE,\n",
-"    device = BaseConfig.DEVICE,\n",
-"    pin_memory = True,\n",
-"    num_workers = TrainingConfig.NUM_WORKERS,\n",
-")\n",
-"\n",
-"loss_fn = nn.MSELoss()\n",
-"\n",
-"sd = SimpleDiffusion(\n",
-"    num_diffusion_timesteps = TrainingConfig.TIMESTEPS,\n",
-"    img_shape = TrainingConfig.IMG_SHAPE,\n",
-"    device = BaseConfig.DEVICE,\n",
-")\n",
-"\n",
-"scaler = amp.GradScaler()"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Training"
-]
-},
 {
 "cell_type": "code",
 "execution_count": 101,
@@ -1051,13 +993,16 @@
 "        for x0s, _ in loader:\n",
 "            tq.update(1)\n",
 "            \n",
+"            # generate noise\n",
 "            ts = torch.randint(low=1, high=training_config.TIMESTEPS, size=(x0s.shape[0],), device=base_config.DEVICE)\n",
 "            xts, gt_noise = forward_diffusion(sd, x0s, ts)\n",
 "\n",
+"            # forward & get loss\n",
 "            with amp.autocast():\n",
 "                pred_noise = model(xts, ts)\n",
 "                loss = loss_fn(gt_noise, pred_noise)\n",
 "\n",
+"            # gradient scaling and backpropagation\n",
 "            optimizer.zero_grad(set_to_none=True)\n",
 "            scaler.scale(loss).backward()\n",
 "\n",
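The excerpt cuts off after the backward pass. For context: forward_diffusion conventionally draws a noise sample eps and returns the closed-form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and a torch.cuda.amp loop like this one normally finishes with scaler.step and scaler.update. A sketch of one full training step under those assumptions (sqrt_alpha_cumulative is an assumed counterpart to the sqrt_one_minus_alpha_cumulative buffer shown earlier, and the step/update tail is the standard continuation, not code from this commit):

    import torch

    def train_step(model, loss_fn, optimizer, scaler, sd, x0s, ts):
        # Closed-form forward diffusion:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        eps = torch.randn_like(x0s)
        xts = (sd.sqrt_alpha_cumulative[ts, None, None, None] * x0s
               + sd.sqrt_one_minus_alpha_cumulative[ts, None, None, None] * eps)

        # Mixed-precision forward pass; the model is trained to predict eps.
        with torch.cuda.amp.autocast():
            pred_noise = model(xts, ts)
            loss = loss_fn(eps, pred_noise)

        # Scaled backward pass, then the standard GradScaler tail.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.item()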