add diffusion.ipynb

This commit is contained in:
Hanzhang ma 2024-03-31 14:51:57 +02:00
parent 62e0019201
commit 0436bbe211

125
diffusion.ipynb Normal file
View File

@ -0,0 +1,125 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image \n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_device(data, device):\n",
" \"\"\"将张量移动到选择的设备\"\"\"\n",
" \"\"\"Move tensor(s) to chosen device\"\"\"\n",
" if isinstance(data, (list, tuple)):\n",
" return [to_device(x, device) for x in data]\n",
" return data.to(device, non_blocking=true)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'get_default_device' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n",
"Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m 6\u001b[0m DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n",
"\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined"
]
}
],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
" DEVICE = get_default_device()\n",
" DATASET = \"Flowers\" #MNIST \"cifar-10\" \"Flowers\"\n",
"\n",
" # 记录推断日志信息并保存存档点\n",
" root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
" root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n",
"\n",
" #目前的日志和存档点目录\n",
" log_dir = \"version_0\"\n",
" checkpoint_dir = \"version_0\"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
" TIMESTEPS = 1000\n",
" IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)\n",
" NUM_EPOCHS = 800\n",
" BATCH_SIZE = 32\n",
" LR = 2e-4\n",
" NUM_WORKERS = 2"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "DLML",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}