add helper function(1)
This commit is contained in:
parent
0436bbe211
commit
c0593c9371
181
diffusion.ipynb
181
diffusion.ipynb
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -44,7 +44,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -58,22 +58,177 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DeviceDataLoader:\n",
|
||||
" \"\"\"包装一个数据加载器,来把数据移动到另一个设备上\"\"\"\n",
|
||||
" \"\"\"Wrap a dataloader to move data to a device\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, dl, device):\n",
|
||||
" self.dl = dl\n",
|
||||
" self.device = device\n",
|
||||
"\n",
|
||||
" def __iter__(self):\n",
|
||||
" \"\"\"在移动到设备后生成一个批次的数据\"\"\"\n",
|
||||
" \"\"\"Yield a batch of data after moving it to device\"\"\"\n",
|
||||
" for b in self.dl:\n",
|
||||
" yield to_device(b, self.device)\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" \"\"\"批次的数量\"\"\"\n",
|
||||
" \"\"\"Number of batches\"\"\"\n",
|
||||
" return len(self.dl)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_default_device():\n",
|
||||
" return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def save_images(images, path, **kwargs):\n",
|
||||
" grid = make_grid(images, **kwargs)\n",
|
||||
" ndarr = grid.permute(1,2,0).to(\"cpu\").numpy()\n",
|
||||
" im = Image.fromarray(ndarr)\n",
|
||||
" im.save(path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get(element: torch.Tensor, t: torch.Tensor):\n",
|
||||
" \"\"\"\n",
|
||||
" Get value at index position \"t\" in \"element\" and \n",
|
||||
" reshape it to have the same dimension as a batch of images\n",
|
||||
"\n",
|
||||
" 获得在\"element\"中位置\"t\"并且reshape,以和一组照片有相同的维度\n",
|
||||
" \"\"\"\n",
|
||||
" ele = element.gather(-1, t)\n",
|
||||
" return ele.reshape(-1, 1, 1, 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"element = torch.tensor([[1,2,3,4,5],\n",
|
||||
" [2,3,4,5,6],\n",
|
||||
" [3,4,5,6,7]])\n",
|
||||
"t = torch.tensor([1,2,0]).unsqueeze(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'get_default_device' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n",
|
||||
"Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m 6\u001b[0m DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n",
|
||||
"\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[1, 2, 3, 4, 5],\n",
|
||||
" [2, 3, 4, 5, 6],\n",
|
||||
" [3, 4, 5, 6, 7]])\n",
|
||||
"tensor([[1],\n",
|
||||
" [2],\n",
|
||||
" [0]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(element)\n",
|
||||
"print(t)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[[[2]]],\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" [[[4]]],\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" [[[3]]]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"extracted_scores = get(element, t)\n",
|
||||
"print(extracted_scores)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def setup_log_directory(config):\n",
|
||||
" \"\"\"\n",
|
||||
" Log and Model checkpoint directory Setup\n",
|
||||
" 记录并且建模目录准备\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" if os.path.isdir(config.root_log_dir):\n",
|
||||
" # Get all folders numbers in the root_log_dir\n",
|
||||
" # 在root_log_dir下获得所有文件夹数目\n",
|
||||
" folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n",
|
||||
"\n",
|
||||
" # Find the latest version number present in the log_dir\n",
|
||||
" # 找到在log_dir下的最新版本数字\n",
|
||||
" last_version_number = max(folder_numbers)\n",
|
||||
"\n",
|
||||
" # New version name\n",
|
||||
" version_name = f\"version{last_version_number + 1}\"\n",
|
||||
"\n",
|
||||
" else:\n",
|
||||
" version_name = config.log_dir\n",
|
||||
"\n",
|
||||
" # Update the training config default directory\n",
|
||||
" # 更新训练config默认目录\n",
|
||||
" log_dir = os.path.join(config.root_log_dir, version_name)\n",
|
||||
" checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name) \n",
|
||||
"\n",
|
||||
" # Create new directory for saving new experiment version\n",
|
||||
" # 创建一个新目录来保存新的实验版本\n",
|
||||
" os.makedirs(log_dir, exist_ok=True)\n",
|
||||
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
" print(f\"Logging at: {log_dir}\")\n",
|
||||
" print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
|
||||
"\n",
|
||||
" return log_dir, checkpoint_dir\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
@ -117,7 +272,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.0"
|
||||
"version": "3.11.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
Reference in New Issue
Block a user