MasterThesis/diffusion.ipynb
2024-04-01 00:16:59 +02:00

281 lines
7.6 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image \n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def to_device(data, device):\n",
" \"\"\"将张量移动到选择的设备\"\"\"\n",
" \"\"\"Move tensor(s) to chosen device\"\"\"\n",
" if isinstance(data, (list, tuple)):\n",
" return [to_device(x, device) for x in data]\n",
" return data.to(device, non_blocking=true)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class DeviceDataLoader:\n",
" \"\"\"包装一个数据加载器,来把数据移动到另一个设备上\"\"\"\n",
" \"\"\"Wrap a dataloader to move data to a device\"\"\"\n",
"\n",
" def __init__(self, dl, device):\n",
" self.dl = dl\n",
" self.device = device\n",
"\n",
" def __iter__(self):\n",
" \"\"\"在移动到设备后生成一个批次的数据\"\"\"\n",
" \"\"\"Yield a batch of data after moving it to device\"\"\"\n",
" for b in self.dl:\n",
" yield to_device(b, self.device)\n",
"\n",
" def __len__(self):\n",
" \"\"\"批次的数量\"\"\"\n",
" \"\"\"Number of batches\"\"\"\n",
" return len(self.dl)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_default_device():\n",
" return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def save_images(images, path, **kwargs):\n",
"    \"\"\"Tile a batch of images into a single grid and save it to `path`.\n",
"\n",
"    Args:\n",
"        images: image batch accepted by `torchvision.utils.make_grid`\n",
"            (e.g. a (B, C, H, W) tensor). uint8 input is saved as-is;\n",
"            floating-point input is assumed to be in [0, 1] and is scaled\n",
"            to [0, 255] the way `torchvision.utils.save_image` does (pass\n",
"            `normalize=True` / `value_range=...` for other ranges).\n",
"        path: output file path; PIL infers the format from the extension.\n",
"        **kwargs: forwarded to `make_grid` (e.g. `nrow`, `padding`).\n",
"    \"\"\"\n",
"    grid = make_grid(images, **kwargs)\n",
"    if grid.is_floating_point():\n",
"        # PIL cannot build an image from a float array: scale, round, clamp\n",
"        # and cast to uint8 (detach so tensors that require grad work too).\n",
"        grid = grid.detach().mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)\n",
"    # (C, H, W) -> (H, W, C), the layout PIL expects.\n",
"    ndarr = grid.permute(1, 2, 0).to(\"cpu\").numpy()\n",
"    im = Image.fromarray(ndarr)\n",
"    im.save(path)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def get(element: torch.Tensor, t: torch.Tensor):\n",
" \"\"\"\n",
" Get value at index position \"t\" in \"element\" and \n",
" reshape it to have the same dimension as a batch of images\n",
"\n",
" 获得在\"element\"中位置\"t\"并且reshape以和一组照片有相同的维度\n",
" \"\"\"\n",
" ele = element.gather(-1, t)\n",
" return ele.reshape(-1, 1, 1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"element = torch.tensor([[1,2,3,4,5],\n",
" [2,3,4,5,6],\n",
" [3,4,5,6,7]])\n",
"t = torch.tensor([1,2,0]).unsqueeze(1)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1, 2, 3, 4, 5],\n",
" [2, 3, 4, 5, 6],\n",
" [3, 4, 5, 6, 7]])\n",
"tensor([[1],\n",
" [2],\n",
" [0]])\n"
]
}
],
"source": [
"# Show both demo inputs, one per line.\n",
"print(element, t, sep=\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[2]]],\n",
"\n",
"\n",
" [[[4]]],\n",
"\n",
"\n",
" [[[3]]]])\n"
]
}
],
"source": [
"# One value per row: element[0, 1], element[1, 2], element[2, 0].\n",
"extracted_scores = get(element=element, t=t)\n",
"print(extracted_scores)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def setup_log_directory(config):\n",
" \"\"\"\n",
" Log and Model checkpoint directory Setup\n",
" 记录并且建模目录准备\n",
" \"\"\"\n",
"\n",
" if os.path.isdir(config.root_log_dir):\n",
" # Get all folders numbers in the root_log_dir\n",
" # 在root_log_dir下获得所有文件夹数目\n",
" folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n",
"\n",
" # Find the latest version number present in the log_dir\n",
" # 找到在log_dir下的最新版本数字\n",
" last_version_number = max(folder_numbers)\n",
"\n",
" # New version name\n",
" version_name = f\"version{last_version_number + 1}\"\n",
"\n",
" else:\n",
" version_name = config.log_dir\n",
"\n",
" # Update the training config default directory\n",
" # 更新训练config默认目录\n",
" log_dir = os.path.join(config.root_log_dir, version_name)\n",
" checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name) \n",
"\n",
" # Create new directory for saving new experiment version\n",
" # 创建一个新目录来保存新的实验版本\n",
" os.makedirs(log_dir, exist_ok=True)\n",
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
" print(f\"Logging at: {log_dir}\")\n",
" print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
"\n",
" return log_dir, checkpoint_dir\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
"    # Settings shared by the whole notebook. All values are plain class\n",
"    # attributes (no type annotations), so the dataclass declares no fields.\n",
"    DEVICE = get_default_device()\n",
"    DATASET = \"Flowers\"  # one of: \"MNIST\", \"cifar-10\", \"Flowers\"\n",
"\n",
"    # Root folders for inference logs and for model checkpoints.\n",
"    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
"    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\", \"checkpoints\")\n",
"\n",
"    # Current log / checkpoint sub-directory names.\n",
"    log_dir = \"version_0\"\n",
"    checkpoint_dir = \"version_0\"\n",
"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
"    # Training hyper-parameters.\n",
"    TIMESTEPS = 1000\n",
"    # MNIST is single-channel; the other datasets are treated as RGB.\n",
"    IMG_SHAPE = (1 if BaseConfig.DATASET == \"MNIST\" else 3, 32, 32)\n",
"    NUM_EPOCHS = 800\n",
"    BATCH_SIZE = 32\n",
"    LR = 2e-4\n",
"    NUM_WORKERS = 2"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "DLML",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}