{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def to_device(data, device):\n",
"    \"\"\"Move tensor(s) to the chosen device.\"\"\"\n",
"    if isinstance(data, (list, tuple)):\n",
"        return [to_device(x, device) for x in data]\n",
"    return data.to(device, non_blocking=True)"
]
},
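{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity-check sketch of `to_device`: the dummy image/label tensors below are illustrative only, and the device falls back to CPU when CUDA is unavailable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dummy (image, label) batch just to exercise to_device; names are placeholders.\n",
"_device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"_dummy_batch = (torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)))\n",
"_moved = to_device(_dummy_batch, _device)\n",
"print([x.device for x in _moved])"
]
},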
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class DeviceDataLoader:\n",
"    \"\"\"Wrap a dataloader to move data to a device.\"\"\"\n",
"\n",
"    def __init__(self, dl, device):\n",
"        self.dl = dl\n",
"        self.device = device\n",
"\n",
"    def __iter__(self):\n",
"        \"\"\"Yield a batch of data after moving it to the device.\"\"\"\n",
"        for b in self.dl:\n",
"            yield to_device(b, self.device)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of batches.\"\"\"\n",
"        return len(self.dl)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_default_device():\n",
"    return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
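{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how `DeviceDataLoader` wraps a plain `DataLoader`: the `TensorDataset` below is placeholder data, used only to show that every yielded batch already lives on the target device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import TensorDataset\n",
"\n",
"# Placeholder dataset: 16 random 3x32x32 \"images\" with integer labels.\n",
"_ds = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,)))\n",
"_dl = DataLoader(_ds, batch_size=4, shuffle=True)\n",
"_device_dl = DeviceDataLoader(_dl, get_default_device())\n",
"\n",
"for xb, yb in _device_dl:\n",
"    print(xb.shape, xb.device, yb.shape)\n",
"    break"
]
},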
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def save_images(images, path, **kwargs):\n",
"    # Expects a batch of uint8 image tensors; make_grid tiles them into one image.\n",
"    grid = make_grid(images, **kwargs)\n",
"    ndarr = grid.permute(1, 2, 0).to(\"cpu\").numpy()\n",
"    im = Image.fromarray(ndarr)\n",
"    im.save(path)"
]
},
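{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short usage sketch for `save_images`, assuming the samples have already been converted to `uint8` (as `Image.fromarray` expects); the random tensor and output file name are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Random uint8 images stand in for generated samples; the path is arbitrary.\n",
"_fake_samples = torch.randint(0, 256, (16, 3, 32, 32), dtype=torch.uint8)\n",
"save_images(_fake_samples, \"sample_grid.png\", nrow=4)"
]
},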
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def get(element: torch.Tensor, t: torch.Tensor):\n",
"    \"\"\"\n",
"    Get the value at index position \"t\" in \"element\" and\n",
"    reshape it so that it broadcasts against a batch of images.\n",
"    \"\"\"\n",
"    ele = element.gather(-1, t)\n",
"    return ele.reshape(-1, 1, 1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"element = torch.tensor([[1,2,3,4,5],\n",
"                        [2,3,4,5,6],\n",
"                        [3,4,5,6,7]])\n",
"t = torch.tensor([1,2,0]).unsqueeze(1)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1, 2, 3, 4, 5],\n",
"        [2, 3, 4, 5, 6],\n",
"        [3, 4, 5, 6, 7]])\n",
"tensor([[1],\n",
"        [2],\n",
"        [0]])\n"
]
}
],
"source": [
"print(element)\n",
"print(t)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[2]]],\n",
"\n",
"\n",
"        [[[4]]],\n",
"\n",
"\n",
"        [[[3]]]])\n"
]
}
],
"source": [
"extracted_scores = get(element, t)\n",
"print(extracted_scores)"
]
},
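{
"cell_type": "markdown",
"metadata": {},
"source": [
"In a diffusion-style training loop, a helper like `get` is typically used to look up per-sample noise-schedule values at random timesteps and broadcast them over an image batch. The sketch below uses a made-up linear beta schedule and dummy images purely to show the resulting shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative linear beta schedule (assumed values, not a final training schedule).\n",
"betas = torch.linspace(1e-4, 0.02, 1000)\n",
"alpha_cumprod = torch.cumprod(1.0 - betas, dim=0)\n",
"\n",
"x0 = torch.randn(8, 3, 32, 32)      # dummy image batch\n",
"ts = torch.randint(0, 1000, (8,))   # one timestep per sample\n",
"scale = get(alpha_cumprod, ts)      # gathered and reshaped to (8, 1, 1, 1)\n",
"print(scale.shape, (scale.sqrt() * x0).shape)"
]
},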
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def setup_log_directory(config):\n",
"    \"\"\"Set up the log and model checkpoint directories.\"\"\"\n",
"\n",
"    if os.path.isdir(config.root_log_dir):\n",
"        # Get all folder numbers in the root_log_dir\n",
"        folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n",
"\n",
"        # Find the latest version number present in the log_dir\n",
"        last_version_number = max(folder_numbers)\n",
"\n",
"        # New version name\n",
"        version_name = f\"version_{last_version_number + 1}\"\n",
"\n",
"    else:\n",
"        version_name = config.log_dir\n",
"\n",
"    # Update the training config default directories\n",
"    log_dir = os.path.join(config.root_log_dir, version_name)\n",
"    checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)\n",
"\n",
"    # Create new directories for saving this experiment version\n",
"    os.makedirs(log_dir, exist_ok=True)\n",
"    os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
"    print(f\"Logging at: {log_dir}\")\n",
"    print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
"\n",
"    return log_dir, checkpoint_dir\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
"    DEVICE = get_default_device()\n",
"    DATASET = \"Flowers\"  # \"MNIST\", \"cifar-10\", \"Flowers\"\n",
"\n",
"    # Root directories for logging inference results and saving checkpoints\n",
"    root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
"    root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\", \"checkpoints\")\n",
"\n",
"    # Current log and checkpoint directories\n",
"    log_dir = \"version_0\"\n",
"    checkpoint_dir = \"version_0\"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
"    TIMESTEPS = 1000\n",
"    IMG_SHAPE = (1, 32, 32) if BaseConfig.DATASET == \"MNIST\" else (3, 32, 32)\n",
"    NUM_EPOCHS = 800\n",
"    BATCH_SIZE = 32\n",
"    LR = 2e-4\n",
"    NUM_WORKERS = 2"
]
},
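{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of wiring the configs and helpers together: report the chosen device, then create (or version up) the log and checkpoint folders from `BaseConfig`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Using device: {BaseConfig.DEVICE}\")\n",
"\n",
"# Creates Logs_Checkpoints/Inference/version_x and Logs_Checkpoints/checkpoints/version_x.\n",
"log_dir, checkpoint_dir = setup_log_directory(config=BaseConfig())"
]
}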
],
"metadata": {
"kernelspec": {
"display_name": "DLML",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}