add diffusion.ipynb
This commit is contained in:
parent
62e0019201
commit
0436bbe211
125
diffusion.ipynb
Normal file
125
diffusion.ipynb
Normal file
@ -0,0 +1,125 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gc\n",
|
||||
"import os\n",
|
||||
"import cv2\n",
|
||||
"import math\n",
|
||||
"import base64\n",
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"from PIL import Image \n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from datetime import datetime\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.cuda import amp\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.optim import Adam, AdamW\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"import torchvision\n",
|
||||
"import torchvision.transforms as TF\n",
|
||||
"import torchvision.datasets as datasets\n",
|
||||
"from torchvision.utils import make_grid\n",
|
||||
"\n",
|
||||
"from torchmetrics import MeanMetric\n",
|
||||
"\n",
|
||||
"from IPython.display import display, HTML, clear_output\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Helper functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def to_device(data, device):\n",
|
||||
" \"\"\"将张量移动到选择的设备\"\"\"\n",
|
||||
" \"\"\"Move tensor(s) to chosen device\"\"\"\n",
|
||||
" if isinstance(data, (list, tuple)):\n",
|
||||
" return [to_device(x, device) for x in data]\n",
|
||||
" return data.to(device, non_blocking=true)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'get_default_device' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdataclasses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m dataclass\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129;43m@dataclass\u001b[39;49m\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mBaseConfig\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mget_default_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mDATASET\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFlowers\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;49;00m\n",
|
||||
"Cell \u001b[1;32mIn[5], line 5\u001b[0m, in \u001b[0;36mBaseConfig\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mBaseConfig\u001b[39;00m:\n\u001b[1;32m----> 5\u001b[0m DEVICE \u001b[38;5;241m=\u001b[39m \u001b[43mget_default_device\u001b[49m()\n\u001b[0;32m 6\u001b[0m DATASET \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFlowers\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m#MNIST \"cifar-10\" \"Flowers\"\u001b[39;00m\n\u001b[0;32m 8\u001b[0m \u001b[38;5;66;03m# 记录推断日志信息并保存存档点\u001b[39;00m\n",
|
||||
"\u001b[1;31mNameError\u001b[0m: name 'get_default_device' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class BaseConfig:\n",
|
||||
" DEVICE = get_default_device()\n",
|
||||
" DATASET = \"Flowers\" #MNIST \"cifar-10\" \"Flowers\"\n",
|
||||
"\n",
|
||||
" # 记录推断日志信息并保存存档点\n",
|
||||
" root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
|
||||
" root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\",\"checkpoints\")\n",
|
||||
"\n",
|
||||
" #目前的日志和存档点目录\n",
|
||||
" log_dir = \"version_0\"\n",
|
||||
" checkpoint_dir = \"version_0\"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class TrainingConfig:\n",
|
||||
" TIMESTEPS = 1000\n",
|
||||
" IMG_SHAPE = (1,32,32) if BaseConfig.DATASET == \"MNIST\" else (3,32,32)\n",
|
||||
" NUM_EPOCHS = 800\n",
|
||||
" BATCH_SIZE = 32\n",
|
||||
" LR = 2e-4\n",
|
||||
" NUM_WORKERS = 2"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "DLML",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in New Issue
Block a user