MasterThesis/Generating_MNIST_using_DDPMs.ipynb

2015 lines
4.2 MiB
Plaintext
Raw Normal View History

2024-04-08 11:37:01 +02:00
{
"cells": [
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 6,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.034694Z",
"start_time": "2023-02-23T07:34:50.200674Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:33.905754Z",
"iopub.status.busy": "2023-02-22T16:01:33.904879Z",
"iopub.status.idle": "2023-02-22T16:01:36.156631Z",
"shell.execute_reply": "2023-02-22T16:01:36.155423Z",
"shell.execute_reply.started": "2023-02-22T16:01:33.905665Z"
}
},
2024-04-09 09:31:18 +02:00
"outputs": [],
2024-04-08 11:37:01 +02:00
"source": [
"import gc\n",
"import os\n",
"import cv2\n",
"import math\n",
"import base64\n",
"import random\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.cuda import amp\n",
"import torch.nn.functional as F\n",
"from torch.optim import Adam, AdamW\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as TF\n",
"import torchvision.datasets as datasets\n",
"from torchvision.utils import make_grid\n",
"\n",
"from torchmetrics import MeanMetric\n",
"\n",
"from IPython.display import display, HTML, clear_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# UNet Model"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 7,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.142674Z",
"start_time": "2023-02-23T07:34:53.114668Z"
},
"code_folding": [
0,
1,
25,
105
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.165052Z",
"iopub.status.busy": "2023-02-22T16:01:36.164311Z",
"iopub.status.idle": "2023-02-22T16:01:36.199528Z",
"shell.execute_reply": "2023-02-22T16:01:36.198080Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.165002Z"
}
},
"outputs": [],
"source": [
"class SinusoidalPositionEmbeddings(nn.Module):\n",
" def __init__(self, total_time_steps=1000, time_emb_dims=128, time_emb_dims_exp=512):\n",
" super().__init__()\n",
"\n",
" half_dim = time_emb_dims // 2\n",
"\n",
" emb = math.log(10000) / (half_dim - 1)\n",
" emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n",
"\n",
" ts = torch.arange(total_time_steps, dtype=torch.float32)\n",
"\n",
" emb = torch.unsqueeze(ts, dim=-1) * torch.unsqueeze(emb, dim=0)\n",
" emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n",
"\n",
" self.time_blocks = nn.Sequential(\n",
" nn.Embedding.from_pretrained(emb),\n",
" nn.Linear(in_features=time_emb_dims, out_features=time_emb_dims_exp),\n",
" nn.SiLU(),\n",
" nn.Linear(in_features=time_emb_dims_exp, out_features=time_emb_dims_exp),\n",
" )\n",
"\n",
" def forward(self, time):\n",
" return self.time_blocks(time)\n",
"\n",
"\n",
"class AttentionBlock(nn.Module):\n",
" def __init__(self, channels=64):\n",
" super().__init__()\n",
" self.channels = channels\n",
"\n",
" self.group_norm = nn.GroupNorm(num_groups=8, num_channels=channels)\n",
" self.mhsa = nn.MultiheadAttention(embed_dim=self.channels, num_heads=4, batch_first=True)\n",
"\n",
" def forward(self, x):\n",
" B, _, H, W = x.shape\n",
" h = self.group_norm(x)\n",
" h = h.reshape(B, self.channels, H * W).swapaxes(1, 2) # [B, C, H, W] --> [B, C, H * W] --> [B, H*W, C]\n",
" h, _ = self.mhsa(h, h, h) # [B, H*W, C]\n",
" h = h.swapaxes(2, 1).view(B, self.channels, H, W) # [B, C, H*W] --> [B, C, H, W]\n",
" return x + h\n",
"\n",
"\n",
"class ResnetBlock(nn.Module):\n",
" def __init__(self, *, in_channels, out_channels, dropout_rate=0.1, time_emb_dims=512, apply_attention=False):\n",
" super().__init__()\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
"\n",
" self.act_fn = nn.SiLU()\n",
" # Group 1\n",
" self.normlize_1 = nn.GroupNorm(num_groups=8, num_channels=self.in_channels)\n",
" self.conv_1 = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"\n",
" # Group 2 time embedding\n",
" self.dense_1 = nn.Linear(in_features=time_emb_dims, out_features=self.out_channels)\n",
"\n",
" # Group 3\n",
" self.normlize_2 = nn.GroupNorm(num_groups=8, num_channels=self.out_channels)\n",
" self.dropout = nn.Dropout2d(p=dropout_rate)\n",
" self.conv_2 = nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"\n",
" if self.in_channels != self.out_channels:\n",
" self.match_input = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=1)\n",
" else:\n",
" self.match_input = nn.Identity()\n",
"\n",
" if apply_attention:\n",
" self.attention = AttentionBlock(channels=self.out_channels)\n",
" else:\n",
" self.attention = nn.Identity()\n",
"\n",
" def forward(self, x, t):\n",
" # group 1\n",
" h = self.act_fn(self.normlize_1(x))\n",
" h = self.conv_1(h)\n",
"\n",
" # group 2\n",
" # add in timestep embedding\n",
" h += self.dense_1(self.act_fn(t))[:, :, None, None]\n",
"\n",
" # group 3\n",
" h = self.act_fn(self.normlize_2(h))\n",
" h = self.dropout(h)\n",
" h = self.conv_2(h)\n",
"\n",
" # Residual and attention\n",
" h = h + self.match_input(x)\n",
" h = self.attention(h)\n",
"\n",
" return h\n",
"\n",
"\n",
"class DownSample(nn.Module):\n",
" def __init__(self, channels):\n",
" super().__init__()\n",
" self.downsample = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)\n",
"\n",
" def forward(self, x, *args):\n",
" return self.downsample(x)\n",
"\n",
"\n",
"class UpSample(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super().__init__()\n",
"\n",
" self.upsample = nn.Sequential(\n",
" nn.Upsample(scale_factor=2, mode=\"nearest\"),\n",
" nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),\n",
" )\n",
"\n",
" def forward(self, x, *args):\n",
" return self.upsample(x)\n",
"\n",
"\n",
"class UNet(nn.Module):\n",
" def __init__(\n",
" self,\n",
" input_channels=3,\n",
" output_channels=3,\n",
" num_res_blocks=2,\n",
" base_channels=128,\n",
" base_channels_multiples=(1, 2, 4, 8),\n",
" apply_attention=(False, False, True, False),\n",
" dropout_rate=0.1,\n",
" time_multiple=4,\n",
" ):\n",
" super().__init__()\n",
"\n",
" time_emb_dims_exp = base_channels * time_multiple\n",
" self.time_embeddings = SinusoidalPositionEmbeddings(time_emb_dims=base_channels, time_emb_dims_exp=time_emb_dims_exp)\n",
"\n",
" self.first = nn.Conv2d(in_channels=input_channels, out_channels=base_channels, kernel_size=3, stride=1, padding=\"same\")\n",
"\n",
" num_resolutions = len(base_channels_multiples)\n",
"\n",
" # Encoder part of the UNet. Dimension reduction.\n",
" self.encoder_blocks = nn.ModuleList()\n",
" curr_channels = [base_channels]\n",
" in_channels = base_channels\n",
"\n",
" for level in range(num_resolutions):\n",
" out_channels = base_channels * base_channels_multiples[level]\n",
"\n",
" for _ in range(num_res_blocks):\n",
"\n",
" block = ResnetBlock(\n",
" in_channels=in_channels,\n",
" out_channels=out_channels,\n",
" dropout_rate=dropout_rate,\n",
" time_emb_dims=time_emb_dims_exp,\n",
" apply_attention=apply_attention[level],\n",
" )\n",
" self.encoder_blocks.append(block)\n",
"\n",
" in_channels = out_channels\n",
" curr_channels.append(in_channels)\n",
"\n",
" if level != (num_resolutions - 1):\n",
" self.encoder_blocks.append(DownSample(channels=in_channels))\n",
" curr_channels.append(in_channels)\n",
"\n",
" # Bottleneck in between\n",
" self.bottleneck_blocks = nn.ModuleList(\n",
" (\n",
" ResnetBlock(\n",
" in_channels=in_channels,\n",
" out_channels=in_channels,\n",
" dropout_rate=dropout_rate,\n",
" time_emb_dims=time_emb_dims_exp,\n",
" apply_attention=True,\n",
" ),\n",
" ResnetBlock(\n",
" in_channels=in_channels,\n",
" out_channels=in_channels,\n",
" dropout_rate=dropout_rate,\n",
" time_emb_dims=time_emb_dims_exp,\n",
" apply_attention=False,\n",
" ),\n",
" )\n",
" )\n",
"\n",
" # Decoder part of the UNet. Dimension restoration with skip-connections.\n",
" self.decoder_blocks = nn.ModuleList()\n",
"\n",
" for level in reversed(range(num_resolutions)):\n",
" out_channels = base_channels * base_channels_multiples[level]\n",
"\n",
" for _ in range(num_res_blocks + 1):\n",
" encoder_in_channels = curr_channels.pop()\n",
" block = ResnetBlock(\n",
" in_channels=encoder_in_channels + in_channels,\n",
" out_channels=out_channels,\n",
" dropout_rate=dropout_rate,\n",
" time_emb_dims=time_emb_dims_exp,\n",
" apply_attention=apply_attention[level],\n",
" )\n",
"\n",
" in_channels = out_channels\n",
" self.decoder_blocks.append(block)\n",
"\n",
" if level != 0:\n",
" self.decoder_blocks.append(UpSample(in_channels))\n",
"\n",
" self.final = nn.Sequential(\n",
" nn.GroupNorm(num_groups=8, num_channels=in_channels),\n",
" nn.SiLU(),\n",
" nn.Conv2d(in_channels=in_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=\"same\"),\n",
" )\n",
"\n",
" def forward(self, x, t):\n",
"\n",
" time_emb = self.time_embeddings(t)\n",
"\n",
" h = self.first(x)\n",
" outs = [h]\n",
"\n",
" for layer in self.encoder_blocks:\n",
" h = layer(h, time_emb)\n",
" outs.append(h)\n",
"\n",
" for layer in self.bottleneck_blocks:\n",
" h = layer(h, time_emb)\n",
"\n",
" for layer in self.decoder_blocks:\n",
" if isinstance(layer, ResnetBlock):\n",
" out = outs.pop()\n",
" h = torch.cat([h, out], dim=1)\n",
" h = layer(h, time_emb)\n",
"\n",
" h = self.final(h)\n",
"\n",
" return h"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Helper Functions"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 8,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:53.203694Z",
"start_time": "2023-02-23T07:34:53.177669Z"
},
"code_folding": [
0,
25,
31,
39,
68
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.201566Z",
"iopub.status.busy": "2023-02-22T16:01:36.201193Z",
"iopub.status.idle": "2023-02-22T16:01:36.219110Z",
"shell.execute_reply": "2023-02-22T16:01:36.218110Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.201528Z"
}
},
"outputs": [],
"source": [
"def to_device(data, device):\n",
" \"\"\"Move tensor(s) to chosen device\"\"\"\n",
" if isinstance(data, (list, tuple)):\n",
" return [to_device(x, device) for x in data]\n",
" return data.to(device, non_blocking=True)\n",
"\n",
"class DeviceDataLoader:\n",
" \"\"\"Wrap a dataloader to move data to a device\"\"\"\n",
"\n",
" def __init__(self, dl, device):\n",
" self.dl = dl\n",
" self.device = device\n",
"\n",
" def __iter__(self):\n",
" \"\"\"Yield a batch of data after moving it to device\"\"\"\n",
" for b in self.dl:\n",
" yield to_device(b, self.device)\n",
"\n",
" def __len__(self):\n",
" \"\"\"Number of batches\"\"\"\n",
" return len(self.dl)\n",
"\n",
"def get_default_device():\n",
" return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"def save_images(images, path, **kwargs):\n",
" grid = torchvision.utils.make_grid(images, **kwargs)\n",
" ndarr = grid.permute(1, 2, 0).to(\"cpu\").numpy()\n",
" im = Image.fromarray(ndarr)\n",
" im.save(path)\n",
" \n",
"def get(element: torch.Tensor, t: torch.Tensor):\n",
" \"\"\"\n",
" Get value at index position \"t\" in \"element\" and\n",
" reshape it to have the same dimension as a batch of images.\n",
" \"\"\"\n",
" ele = element.gather(-1, t)\n",
" return ele.reshape(-1, 1, 1, 1)\n",
"\n",
"def setup_log_directory(config):\n",
" '''Log and Model checkpoint directory Setup'''\n",
" \n",
" if os.path.isdir(config.root_log_dir):\n",
" # Get all folders numbers in the root_log_dir\n",
" folder_numbers = [int(folder.replace(\"version_\", \"\")) for folder in os.listdir(config.root_log_dir)]\n",
" \n",
" # Find the latest version number present in the log_dir\n",
" last_version_number = max(folder_numbers)\n",
"\n",
" # New version name\n",
" version_name = f\"version_{last_version_number + 1}\"\n",
"\n",
" else:\n",
" version_name = config.log_dir\n",
"\n",
" # Update the training config default directory \n",
" log_dir = os.path.join(config.root_log_dir, version_name)\n",
" checkpoint_dir = os.path.join(config.root_checkpoint_dir, version_name)\n",
"\n",
" # Create new directory for saving new experiment version\n",
" os.makedirs(log_dir, exist_ok=True)\n",
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
" print(f\"Logging at: {log_dir}\")\n",
" print(f\"Model Checkpoint at: {checkpoint_dir}\")\n",
" \n",
" return log_dir, checkpoint_dir\n",
"\n",
"def frames2vid(images, save_path):\n",
"\n",
" WIDTH = images[0].shape[1]\n",
" HEIGHT = images[0].shape[0]\n",
"\n",
"# fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
"# fourcc = 0\n",
" fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
" video = cv2.VideoWriter(save_path, fourcc, 25, (WIDTH, HEIGHT))\n",
"\n",
" # Appending the images to the video one by one\n",
" for image in images:\n",
" video.write(image)\n",
"\n",
" # Deallocating memories taken for window creation\n",
"# cv2.destroyAllWindows()\n",
" video.release()\n",
" return \n",
"\n",
"def display_gif(gif_path):\n",
" b64 = base64.b64encode(open(gif_path,'rb').read()).decode('ascii')\n",
" display(HTML(f'<img src=\"data:image/gif;base64,{b64}\" />'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-02T18:23:28.639407Z",
"start_time": "2023-02-02T18:23:28.624407Z"
}
},
"source": [
"# Configurations"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 9,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:54.785018Z",
"start_time": "2023-02-23T07:34:53.639004Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:36.712080Z",
"iopub.status.busy": "2023-02-22T16:01:36.711714Z",
"iopub.status.idle": "2023-02-22T16:01:36.720529Z",
"shell.execute_reply": "2023-02-22T16:01:36.719078Z",
"shell.execute_reply.started": "2023-02-22T16:01:36.712047Z"
}
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class BaseConfig:\n",
" DEVICE = get_default_device()\n",
" DATASET = \"MNIST\" # \"MNIST\", \"Cifar-10\", \"Cifar-100\", \"Flowers\"\n",
" \n",
" # For logging inferece images and saving checkpoints.\n",
" root_log_dir = os.path.join(\"Logs_Checkpoints\", \"Inference\")\n",
" root_checkpoint_dir = os.path.join(\"Logs_Checkpoints\", \"checkpoints\")\n",
"\n",
" # Current log and checkpoint directory.\n",
" log_dir = \"version_0\"\n",
" checkpoint_dir = \"version_0\"\n",
"\n",
"@dataclass\n",
"class TrainingConfig:\n",
" TIMESTEPS = 1000 # Define number of diffusion timesteps\n",
" IMG_SHAPE = (1, 32, 32) if BaseConfig.DATASET == \"MNIST\" else (3, 32, 32) \n",
" NUM_EPOCHS = 30\n",
" BATCH_SIZE = 128\n",
" LR = 2e-4\n",
" NUM_WORKERS = 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-02T18:24:36.837306Z",
"start_time": "2023-02-02T18:24:36.8273Z"
}
},
"source": [
"# Load Dataset & Build Dataloader"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 10,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:34:55.313081Z",
"start_time": "2023-02-23T07:34:55.291079Z"
},
"code_folding": [
0,
21,
37
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:37.597072Z",
"iopub.status.busy": "2023-02-22T16:01:37.596702Z",
"iopub.status.idle": "2023-02-22T16:01:37.608430Z",
"shell.execute_reply": "2023-02-22T16:01:37.607135Z",
"shell.execute_reply.started": "2023-02-22T16:01:37.597040Z"
}
},
"outputs": [],
"source": [
"def get_dataset(dataset_name='MNIST'):\n",
" transforms = TF.Compose(\n",
" [\n",
" TF.ToTensor(),\n",
" TF.Resize((32, 32), \n",
" interpolation=TF.InterpolationMode.BICUBIC, \n",
" antialias=True),\n",
"# TF.RandomHorizontalFlip(),\n",
" TF.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] \n",
" ]\n",
" )\n",
" \n",
" if dataset_name.upper() == \"MNIST\":\n",
" dataset = datasets.MNIST(root=\"data\", train=True, download=True, transform=transforms)\n",
" elif dataset_name == \"Cifar-10\": \n",
" dataset = datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transforms)\n",
" elif dataset_name == \"Cifar-100\":\n",
" dataset = datasets.CIFAR10(root=\"data\", train=True, download=True, transform=transforms)\n",
" elif dataset_name == \"Flowers\":\n",
" dataset = datasets.ImageFolder(root=\"/kaggle/input/flowers-recognition/flowers\", transform=transforms)\n",
" \n",
" return dataset\n",
"\n",
"def get_dataloader(dataset_name='MNIST', \n",
" batch_size=32, \n",
" pin_memory=False, \n",
" shuffle=True, \n",
" num_workers=0, \n",
" device=\"cpu\"\n",
" ):\n",
" dataset = get_dataset(dataset_name=dataset_name)\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, \n",
" pin_memory=pin_memory, \n",
" num_workers=num_workers, \n",
" shuffle=shuffle\n",
" )\n",
" device_dataloader = DeviceDataLoader(dataloader, device)\n",
" return device_dataloader\n",
"\n",
"def inverse_transform(tensors):\n",
" \"\"\"Convert tensors from [-1., 1.] to [0., 255.]\"\"\"\n",
" return ((tensors.clamp(-1, 1) + 1.0) / 2.0) * 255.0 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize Dataset"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 11,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:35:01.528613Z",
"start_time": "2023-02-23T07:34:57.085306Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:39.192513Z",
"iopub.status.busy": "2023-02-22T16:01:39.191545Z",
"iopub.status.idle": "2023-02-22T16:01:39.267958Z",
"shell.execute_reply": "2023-02-22T16:01:39.266946Z",
"shell.execute_reply.started": "2023-02-22T16:01:39.192465Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
2024-04-09 09:31:18 +02:00
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
2024-04-08 11:37:01 +02:00
]
},
{
2024-04-09 09:31:18 +02:00
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 9912422/9912422 [00:00<00:00, 79520059.45it/s]\n"
]
2024-04-08 11:37:01 +02:00
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
2024-04-08 11:37:01 +02:00
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
2024-04-09 09:31:18 +02:00
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
2024-04-08 11:37:01 +02:00
]
},
{
2024-04-09 09:31:18 +02:00
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 28881/28881 [00:00<00:00, 61334528.52it/s]\n"
]
2024-04-08 11:37:01 +02:00
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
2024-04-08 11:37:01 +02:00
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
2024-04-09 09:31:18 +02:00
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
2024-04-08 11:37:01 +02:00
]
},
{
2024-04-09 09:31:18 +02:00
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1648877/1648877 [00:00<00:00, 38547540.84it/s]\n"
]
2024-04-08 11:37:01 +02:00
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
2024-04-08 11:37:01 +02:00
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
2024-04-09 09:31:18 +02:00
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
2024-04-08 11:37:01 +02:00
]
},
{
2024-04-09 09:31:18 +02:00
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4542/4542 [00:00<00:00, 12994903.66it/s]\n"
]
2024-04-08 11:37:01 +02:00
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
2024-04-08 11:37:01 +02:00
"\n"
]
}
],
"source": [
"loader = get_dataloader(\n",
" dataset_name=BaseConfig.DATASET,\n",
" batch_size=128,\n",
" device='cpu',\n",
")"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 12,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:00.707591Z",
"start_time": "2023-02-21T08:03:59.270574Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:39.573357Z",
"iopub.status.busy": "2023-02-22T16:01:39.572658Z",
"iopub.status.idle": "2023-02-22T16:01:39.920898Z",
"shell.execute_reply": "2023-02-22T16:01:39.919984Z",
"shell.execute_reply.started": "2023-02-22T16:01:39.573319Z"
}
},
"outputs": [
{
"data": {
2024-04-09 09:31:18 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA64AAAHiCAYAAADoA5FMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9d2ykWXodDp/KOediFckq5szO3dOT42o2aLXa1WplQAusLNgyYEsyJFkWLBnGwrDhxc+ALViAISjYypKlzZN3ZrZ7prunpwPZzCwWq1isnHMO3x/zPXeLHPZMz86QVeTUAYie7Sa59633fe99wnnO4bRarRZ66KGHHnrooYceeuihhx566KFLwe30AnrooYceeuihhx566KGHHnro4YPQS1x76KGHHnrooYceeuihhx566Gr0Etceeuihhx566KGHHnrooYceuhq9xLWHHnrooYceeuihhx566KGHrkYvce2hhx566KGHHnrooYceeuihq9FLXHvooYceeuihhx566KGHHnroavQS1x566KGHHnrooYceeuihhx66Gr3EtYceeuihhx566KGHHnrooYeuRi9x7aGHHnrooYceeuihhx566KGr0Utce+ihhx566KGHHnrooYceeuhq8D/KN0ejUdy6dQutVuuw1tNRDA4OYnJyEnfv3kUoFOr0cg4FPB4PFy9eBI/Hw7Vr11Cv1zu9pEOB0WjE2bNnsbGxAbfb3enlHBpOnz4NrVaLa9euoVgsdno5hwKFQoGHHnoIoVAI9+7d6/RyDg3j4+NwOBy4efMmEolEp5dzKBAKhbh8+TJKpRJu3ryJZrPZ6SUdCux2O2ZnZ7G4uAi/39/p5RwKuFwuzp8/D7FYjGvXrqFarXZ6SYcCnU6H8+fPY3t7GxsbG51ezqFhbm4OJpMJ169fRy6X6/RyDgUymQwPPfQQ4vE47t692+nlHBqGh4cxOjqKW7duIRqNdno5hwKBQIBLly6hXq/jnXfeQaPR6PSSDgVWqxXz8/NYXV2F1+vt9HIOBRwOB+fOnYPBYHiwH2h9BPzwhz9s8fn8FpfLPZFfv/7rv95qNputr371qx1fy2F9yeXy1s2bN1tLS0sttVrd8fUc1tfnP//5Vr1eb/3e7/1ex9dyWF88Hq/1D//wD61oNNoaGRnp+HoO62tubq6VyWRaf/Inf9LxtRzm13/9r/+1ValUWk888UTH13JYXyaTqeV2u1uvv/56SyQSdXw9h/X1q7/6q61ms9n6xje+0fG1HNaXWCxu/fjHP265XK6W0Wjs+HoO6+upp55qVavV1n/+z/+542s5zK8///M/b6XT6dbs7GzH13JYX2NjY614PN7627/92xaPx+v4eg7r6/d///db9Xq99bnPfa7jazmsL61W21pZWWnduHGjJZPJOr6ew/r6pV/6pVaz2Wz9m3/zbzq+lsP6EggErZdeeumBc9GP1HFttVpoNpsntkpOnWS6zpMIqkqd9HtJ13WS7yXwk2e2dy+PPz4t9/LTcp3050m9xk/LWfJpiAuA967vpN/L9rPkpHbogE/HWfJp2X8+DWcJh8P5SEzej5S49tBDDz300EMPPfTQQw+fDDgcDvsCTn6RpIcePg56iWsPPfTQQw899NBDDz0cIXg8HkwmEwYGBjA7Owuz2Yxms4lIJIKXX34ZmUwGmUzmxOrK9NDDT4Ne4tpDDz30cILB4XAgEAjA4/EgFApZZb9araJer7Pq/kkVauuhhx566DZwuVxIJBLo9Xo4nU7Mz89jYGAA9XodXq8X169fR7lc7vQye+ih69BLXHvooYceTiiEQiGEQiFGRkag0+kwOTkJsVgMsViM1dVV+P1+5PN5FAoF7Ozs9OhpPfTQQw9HAJlMhvn5eZw5cwZPPvkkZmZmYDKZUKvVoNVqoVAokMlkOr3MHnroOvQS1x566KGHEwYulws+nw+DwQCNRsOCopmZGYhEIgiFQggEAmg0Gm
SzWWQyGZRKJZRKJRSLRdaJ7eFkg8vlgsfjQa/XQyAQIBaLoVardXX3ncvlsudXLBajVCqhWq2iVqv1ntkejgUkEgnUajWcTiecTicGBgagVqshEomQy+VQLBZRLpdRq9U6vdQePiXgcrkQi8WQSqVQKBTs7xuNBur1OnK5HGq1GiqVSsf32V7i2kMPPfRwwiAQCKBQKDA9PY3x8XE888wzsNlsmJycBI/HA4fDwdjYGPx+P7LZLCKRCDgcDoLBIHZ2dlAsFk+sL2cPP4FQKIREIsGFCxegUqnw+uuvI5PJIJvNdnpp9wWfz4dWq4VGo4HJZMLu7i5SqRRSqdSJVort4WSAy+VCp9PB4XDgsccew/T0NGZmZgAA9XodHo8H29vbiEajyOfzHU8Sevh0QCAQsHnriYkJJhRWLBaRzWaxsbGBZDKJeDze8cJ2xxNXPp8PkUgEsVgMoVCIQqGAer2OYrHY6aV9ZNCNpj+Bn0jM93DywOFw2JyKSCSCXC4Hj8cDj8fb8z30fdQVSKfTqFare+ZX6HtEIhEAoFqtdlT+fP8z/CCgOUqpVArgPfn2UqmESqVyKGvs4f5QKBQYGRnB2NgYxsbGoNfrIZPJUK1WIRAIWDdWKpWiUqnAZrOhVCrB5/Nhc3MTHo8H8Xgc6XT6U50M0HspFoshEonA4XDA4/EgkUiYhH8ul0O5XEa5XD42VGu6DpPJBLvdjueeew5arRYrKyuo1+tdm7hyuVxIpVJMT09jYGAAk5OTuHr1KlwuF/L5/Kf6We1GUFynUCggEAggkUjA4/HA5/MPPFeKxSKKxSIymQwqlcqJ6zhSR+vMmTMYGRnBxMQETCYTACCRSCCZTOKdd97B6uoqstnsiTo7FQoFxGIx1Go12zspJorH471CaQfA5/PB5/NhMpmg0+kwNzeHsbExnD59msWAtVoNyWQSb7/9Nra2trCysoJkMtnR+euuSFzlcjnUajVkMhlisRiKxSJKpdKxS/i4XC64XO6eBLbRaKDRaBy7a+nhw0F0TIVCAaVSCZPJxChs7c8A0fEKhQJ7rvP5/B7KBSW8crl8j1hOJwLh9mf4oxRehEIhRCIRdDodALD1n6TD97hALpdjcHAQg4ODGBgYgFwuh0AgYM8ch8OBRqOBVqsFABQKBTQaDVitVsjlclY0+bQnA+3vOBWmBAIBtFotOBwOms0mwuEwUqkUKzYdB5Bgl9FoxOjoKB5++GHodDr8yZ/8CVKpVKeXd19wOBxIpVKMjY1hamoKFy9eRDQaRTqdxvb2dqeX10Mb6F7JZDJYLBZIpVJoNBo2d7//XGm1WqyjQ+9Rpzs7nzToM5idncXo6CicTidkMhlarRYSiQT8fj8WFhbgdruRz+ePzX7yYeByuVCpVFAqlRgYGACXy0Wz2UShUEChUEA2m+1R/TsAKib19/fDZrPh4sWLmJmZwcWLF/d8XzQaRb1eB4/HQzgcRqFQ+HQnrhqNBqdOncL8/DwcDgd+9KMfYWdnB++++25Xz9m0QyAQQCgUwmazscoiVbRzuRyy2SwSiQSbxfk0gzqQIpGIqZxScEiJ/35QZa5araJUKnUsodsPiUQCjUaDhx9+GBMTE3j00UchlUpZ17S9a8nhcBCPx5FKpfCjH/0IXq8Xb775JiqVCur1OvR6PTQaDaamplCv17G4uIh0On2kQSRV3ywWCyQSCWq1GgqFAoLB4If+LIfDwejoKCwWCy5cuIBqtYpwOIx33nkHy8vLR7D6Htohk8kwMDCAWCyGeDyOV155BWKxGHNzc1Cr1TAYDGg2m+BwONDr9RCLxZicnMTQ0BAuXrwIq9WK27dv44UXXkAymTxxnY920P5D3SGJRAIAyGQyUKlU6O/vx8TEBPr7+yEWiyGXyzE9PY1arYZsNosXX3wRd+7cwerq6rE5s/h8PpRKJWZmZvDZz34WUqkU2WwW5XK5q88oLpcLuVyOM2fOYHR0FENDQ7BYLNBqteDzOx7OPD
CooEnsB+C9c65YLB7b4J3OcKFQCLlcDpVKhfn5edhsNpw+fRoajQZms5mdMwchFAohEAjglVdewdbWFhYXF1GpVLr
2024-04-08 11:37:01 +02:00
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(12, 6), facecolor='white')\n",
"\n",
"for b_image, _ in loader:\n",
" b_image = inverse_transform(b_image).cpu()\n",
" grid_img = make_grid(b_image / 255.0, nrow=16, padding=True, pad_value=1, normalize=True)\n",
" plt.imshow(grid_img.permute(1, 2, 0))\n",
" plt.axis(\"off\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Diffusion Process"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 13,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:35:01.574603Z",
"start_time": "2023-02-23T07:35:01.561607Z"
},
"code_folding": [
1
],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:40.386727Z",
"iopub.status.busy": "2023-02-22T16:01:40.385687Z",
"iopub.status.idle": "2023-02-22T16:01:40.397948Z",
"shell.execute_reply": "2023-02-22T16:01:40.396829Z",
"shell.execute_reply.started": "2023-02-22T16:01:40.386688Z"
}
},
"outputs": [],
"source": [
"class SimpleDiffusion:\n",
" def __init__(\n",
" self,\n",
" num_diffusion_timesteps=1000,\n",
" img_shape=(3, 64, 64),\n",
" device=\"cpu\",\n",
" ):\n",
" self.num_diffusion_timesteps = num_diffusion_timesteps\n",
" self.img_shape = img_shape\n",
" self.device = device\n",
"\n",
" self.initialize()\n",
"\n",
" def initialize(self):\n",
" # BETAs & ALPHAs required at different places in the Algorithm.\n",
" self.beta = self.get_betas()\n",
" self.alpha = 1 - self.beta\n",
" \n",
" self_sqrt_beta = torch.sqrt(self.beta)\n",
" self.alpha_cumulative = torch.cumprod(self.alpha, dim=0)\n",
" self.sqrt_alpha_cumulative = torch.sqrt(self.alpha_cumulative)\n",
" self.one_by_sqrt_alpha = 1. / torch.sqrt(self.alpha)\n",
" self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1 - self.alpha_cumulative)\n",
" \n",
" def get_betas(self):\n",
" \"\"\"linear schedule, proposed in original ddpm paper\"\"\"\n",
" scale = 1000 / self.num_diffusion_timesteps\n",
" beta_start = scale * 1e-4\n",
" beta_end = scale * 0.02\n",
" return torch.linspace(\n",
" beta_start,\n",
" beta_end,\n",
" self.num_diffusion_timesteps,\n",
" dtype=torch.float32,\n",
" device=self.device,\n",
" )\n",
" \n",
"def forward_diffusion(sd: SimpleDiffusion, x0: torch.Tensor, timesteps: torch.Tensor):\n",
" eps = torch.randn_like(x0) # Noise\n",
" mean = get(sd.sqrt_alpha_cumulative, t=timesteps) * x0 # Image scaled\n",
" std_dev = get(sd.sqrt_one_minus_alpha_cumulative, t=timesteps) # Noise scaled\n",
" sample = mean + std_dev * eps # scaled inputs * scaled noise\n",
"\n",
" return sample, eps # return ... , gt noise --> model predicts this)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample Forward Diffusion Process"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 14,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:15.117858Z",
"start_time": "2023-02-21T08:04:14.427843Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:41.581306Z",
"iopub.status.busy": "2023-02-22T16:01:41.580908Z",
"iopub.status.idle": "2023-02-22T16:01:42.154133Z",
"shell.execute_reply": "2023-02-22T16:01:42.153007Z",
"shell.execute_reply.started": "2023-02-22T16:01:41.581270Z"
}
},
"outputs": [],
"source": [
"sd = SimpleDiffusion(num_diffusion_timesteps=TrainingConfig.TIMESTEPS, device=\"cpu\")\n",
"\n",
"loader = iter( # converting dataloader into an iterator for now.\n",
" get_dataloader(\n",
" dataset_name=BaseConfig.DATASET,\n",
" batch_size=6,\n",
" device=\"cpu\",\n",
" )\n",
")"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 15,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-21T08:04:15.117858Z",
"start_time": "2023-02-21T08:04:14.427843Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:01:41.581306Z",
"iopub.status.busy": "2023-02-22T16:01:41.580908Z",
"iopub.status.idle": "2023-02-22T16:01:42.154133Z",
"shell.execute_reply": "2023-02-22T16:01:42.153007Z",
"shell.execute_reply.started": "2023-02-22T16:01:41.581270Z"
}
},
"outputs": [
{
"data": {
2024-04-09 09:31:18 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAF+CAYAAAAFumw3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d3gc5bn+/1lptatdaVda9d6bZUmWezcuYIONAYMBAzHFQCAUA4EEckgIEHISOgkQ6qF8TTGmGwMuuNu4W7Ka1XvvdXdVduf3h3/ve1a2IZYwIZzMfV2+QFtmZ555Z+Yp93M/GkVRFFSoUKFChQoVKlSoUKHiLMLtp94BFSpUqFChQoUKFSpU/N+DGmioUKFChQoVKlSoUKHirEMNNFSoUKFChQoVKlSoUHHWoQYaKlSoUKFChQoVKlSoOOtQAw0VKlSoUKFChQoVKlScdaiBhgoVKlSoUKFChQoVKs461EBDhQoVKlSoUKFChQoVZx1qoKFChQoVKlSoUKFChYqzDjXQUKFChQoVKlSoUKFCxVmHGmioUKFCxVnEjh070Gg07Nix46xv++GHH0aj0Qx7bWhoiN/+9rdERkbi5ubGJZdcAkBvby833XQTISEhaDQa7r777rO+PzExMVx//fVnfbsqVKhQoeL/BtRAQ4UKFT8K3nrrLTQazWn/PfDAAz/17v3kONk+np6ehIWFsWjRIv7+97/T09NzRtt54403ePLJJ1m+fDlvv/0299xzDwD//d//zVtvvcWvfvUr1qxZw8qVK3/Mw/mX4nS2S0pK4o477qCpqemn3j0VKlSoUPH/Q/tT74AKFSr+b+PRRx8lNjZ22GtpaWk/0d78+0HYZ3BwkMbGRnbs2MHdd9/NM888w/r168nIyJCf/f3vf39KkLZt2zbCw8N59tlnT3l92rRp/PGPf/zR9r2oqAg3t58uXyVsZ7fb2bNnDy+99BJfffUVeXl5GI3Gn2y/VKhQoULFCaiBhgoVKn5UXHDBBUyaNOmsb7evrw8vL6+zvt1/BkVRsNvtGAyGs7K9k+3zu9/9jm3btnHhhRdy0UUXcfz4cflbWq0WrXb4bbu5uRlfX99Tttvc3ExqaupZ2cfvgl6v/1G3/8/garubbroJf39/nnnmGT7//HOuuuqq037np1o3KlSoUPGfCJU6pUKFip8U27ZtY/bs2Xh5eeHr68vFF1/M8ePHh31G9CYUFBRw9dVXY7FYmDVrFuvXr0ej0ZCTkyM/+/HHH6PRaLj00kuHbWPMmDFceeWV8u8333yT+fPnExQUhF6vJzU1lZdeeumU/YuJieHCCy9k06ZNTJo0CYPBwCuvvAJAbW0tl1xyCV5eXgQFBXHPPffQ39//g20yf/58/vCHP1BVVcU777xzih0AKisr0Wg0bN++nfz8fEkjEj0iFRUVfPnll/L1yspKSTmqrKwc9nun6yspKSnhsssuIyQkBE9PTyIiIlixYgVdXV3DbHNyj0Z5eTmXX345fn5+GI1Gpk2bxpdffnna31u3bh1//vOfiYiIwNPTkwULFlBaWvqD7AZQUVEBwPXXX4+3tzdlZWUsXrwYk8nENddcA5wIOO69914iIyPR6/UkJyfz1FNPoSjKKdt95513mDJlCkajEYvFwpw5c9i8efOwz3z99ddyHZtMJpYsWUJ+fv6wzzQ2NnLDDTcQERGBXq8nNDSUiy++eNj5OHz4MIsWLSIgIACDwUBsbCyrVq0atU1UqFCh4qeEWtFQoULFj4quri5aW1uHvRYQEADAN998wwUXXEBcXBwPP/wwNpuN559/npkzZ3L06FFiYmKGfe/yyy8nMTGR//7v/0ZRFGbNmoVGo2HXrl2SYrR7927c3NzYs2eP/F5LSwuFhYXccccd8rWXXnqJsWPHctFFF6HVavniiy+47bbbcDqd3H777cN+t6ioiKuuuopbbrmFm2++meTkZGw2GwsWLKC6uprVq1cTFhbGmj
Vr2LZt21mx28qVK/mv//ovNm/ezM0333zK+4GBgaxZs4Y///nP9Pb28pe//AU4EVCtWbOGe+65h4iICO699175+TPFwMAAixYtor+/nzvvvJOQkBDq6urYsGEDnZ2d+Pj4nPZ7TU1NzJgxA6vVyurVq/H39+ftt9/moosu4qOPPmLZsmXDPv/Xv/4VNzc37rvvPrq6unjiiSe45pprOHDgwBnvqyvKysoA8Pf3l68NDQ2xaNEiZs2axVNPPYXRaERRFC666CK2b9/OjTfeSGZmJps2beI3v/kNdXV1w2hojzzyCA8//DAzZszg0UcfRafTceDAAbZt28bChQsBWLNmDddddx2LFi3i8ccfx2q18tJLLzFr1iyysrLkOr7sssvIz8/nzjvvJCYmhubmZrZs2UJ1dbX8e+HChQQGBvLAAw/g6+tLZWUln3zyyajsoUKFChU/ORQVKlSo+BHw5ptvKsBp/wlkZmYqQUFBSltbm3zt2LFjipubm3LttdfK1/74xz8qgHLVVVed8jtjx45VrrjiCvn3hAkTlMsvv1wBlOPHjyuKoiiffPKJAijHjh2Tn7Naradsa9GiRUpcXNyw16KjoxVA2bhx47DXn3vuOQVQ1q1bJ1/r6+tTEhISFEDZvn37Gdnn0KFD3/kZHx8fZfz48fJvYQdXnHPOOcrYsWNP+W50dLSyZMmS0/5mRUXFsNe3b98+bJ+zsrIUQPnwww+/9xiio6OV6667Tv599913K4Cye/du+VpPT48SGxurxMTEKA6HY9jvjRkzRunv75ef/dvf/qYASm5u7vf+rjiOb775RmlpaVFqamqUtWvXKv7+/orBYFBqa2sVRVGU6667TgGUBx54YNj3P/vsMwVQHnvssWGvL1++XNFoNEppaamiKIpSUlKiuLm5KcuWLZP7LuB0OuXx+fr6KjfffPOw9xsbGxUfHx/5ekdHhwIoTz755Hce16effvpP14QKFSpU/JygUqdUqFDxo+LFF19ky5Ytw/4BNDQ0kJ2dzfXXX4+fn5/8fEZGBueddx5fffXVKdu69dZbT3lt9uzZ7N69G4Cenh6OHTvGL3/5SwICAuTru3fvxtfXd1gTumuPhai6nHPOOZSXlw+jBwHExsayaNGiYa999dVXhIaGsnz5cvma0Wjkl7/85Rnb5p/B29v7jNWnziZExWLTpk1YrdYz/t5XX33FlClTmDVrlnzN29ubX/7yl1RWVlJQUDDs8zfccAM6nU7+PXv2bOAE/epMcO655xIYGEhkZCQrVqzA29ubTz/9lPDw8GGf+9WvfnXKfrq7u7N69ephr997770oisLXX38NwGeffYbT6eShhx46peldUNi2bNlCZ2cnV111Fa2trfKfu7s7U6dOZfv27cCJ9abT6dixYwcdHR2nPR7Ra7NhwwYGBwfPyAYqVKhQ8e8MNdBQoULFj4opU6Zw7rnnDvsHUFVVBUBycvIp3xkzZgytra309fUNe/1k9So44Zw2NDRQWlrKt99+i0ajYfr06cMCkN27dzNz5sxhzuLevXs599xzZW9IYGAg//Vf/wVw2kDjZFRVVZGQkHDKXIvTHc9o0dvbi8lkOmvbO1PExsby61//mtdff52AgAAWLVrEiy++eIpdTkZVVdV3nk/xviuioqKG/W2xWAC+0xE/GSKI3b59OwUFBZSXl58SEGq1WiIiIk7Zz7CwsFNse/J+lpWV4ebm9r1N9SUlJcCJ/pDAwMBh/zZv3kxzczNwonH+8ccf5+uvvyY4OJg5c+bwxBNP0NjYKLd1zjnncNlll/HII48QEBDAxRdfzJtvvnlW+n5UqFCh4qeAGmioUKHiZ4PTKT2J7PmuXbvYvXs3EyZMwMvLSwYavb29ZGVlyWw5nHAgFyxYQGtrK8888wxffvklW7ZskTMonE7nP/3dHxu1tbV0dXWRkJBw1rZ5clAk4HA4Tnnt6aefJicnh//6r//CZrOxevVqxo4dS21t7VnbH3d399
O+rpymIft0EEHs3LlzGTNmzGmldvV6/Y8qwSvWypo1a06p3G3ZsoXPP/9cfvbuu++muLiYv/zlL3h6evKHP/yBMWP
2024-04-08 11:37:01 +02:00
"text/plain": [
"<Figure size 1000x500 with 12 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Pull one batch from `loader`; only the first element of the pair is used in this cell.\n",
"x0s, _ = next(loader)\n",
"\n",
"noisy_images = []\n",
"# Timesteps at which the noised batch is rendered, from t=0 up to t=999.\n",
"specific_timesteps = [0, 10, 50, 100, 150, 200, 250, 300, 400, 600, 800, 999]\n",
"\n",
"for timestep in specific_timesteps:\n",
" # The timestep is handed to forward_diffusion as a 0-dim long tensor, not a Python int.\n",
" timestep = torch.as_tensor(timestep, dtype=torch.long)\n",
"\n",
" # The second return value of forward_diffusion is discarded here.\n",
" xts, _ = forward_diffusion(sd, x0s, timestep)\n",
" # NOTE(review): presumably inverse_transform yields values in [0, 255], so this rescales to [0, 1] for imshow - confirm.\n",
" xts = inverse_transform(xts) / 255.0\n",
" # nrow=1 places one image per grid row, i.e. the batch is stacked vertically.\n",
" xts = make_grid(xts, nrow=1, padding=1)\n",
" \n",
" noisy_images.append(xts)\n",
"\n",
"# Plot and see samples at different timesteps\n",
"\n",
"# One subplot column per entry of noisy_images.\n",
"_, ax = plt.subplots(1, len(noisy_images), figsize=(10, 5), facecolor='white')\n",
"\n",
"for i, (timestep, noisy_sample) in enumerate(zip(specific_timesteps, noisy_images)):\n",
" # squeeze(0) only has an effect when dim 0 has size 1; permute moves dim 0 last for imshow.\n",
" ax[i].imshow(noisy_sample.squeeze(0).permute(1, 2, 0))\n",
" ax[i].set_title(f\"t={timestep}\", fontsize=8)\n",
" ax[i].axis(\"off\")\n",
" ax[i].grid(False)\n",
"\n",
"plt.suptitle(\"Forward Diffusion Process\", y=0.9)\n",
"# plt.axis acts on the current axes only; every subplot was already switched off in the loop above.\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 16,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.619395Z",
"start_time": "2023-02-13T13:01:00.605395Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:43.160882Z",
"iopub.status.busy": "2023-02-22T16:01:43.160156Z",
"iopub.status.idle": "2023-02-22T16:01:43.170501Z",
"shell.execute_reply": "2023-02-22T16:01:43.169283Z",
"shell.execute_reply.started": "2023-02-22T16:01:43.160846Z"
}
},
"outputs": [],
"source": [
"# Algorithm 1: Training\n",
"\n",
"def train_one_epoch(model, sd, loader, optimizer, scaler, loss_fn, epoch=800,\n",
"                    base_config=BaseConfig(), training_config=TrainingConfig()):\n",
"    \"\"\"Train `model` for one pass over `loader` and return the mean batch loss.\n",
"\n",
"    For every batch a timestep is drawn per sample from [1, TIMESTEPS), the batch is\n",
"    noised with `forward_diffusion`, and the model is optimised so that its output\n",
"    matches the noise returned by `forward_diffusion` under `loss_fn`.\n",
"\n",
"    Args:\n",
"        model: Network called as ``model(xts, ts)``; put into train mode here.\n",
"        sd: Diffusion object forwarded unchanged to `forward_diffusion`.\n",
"        loader: Iterable of ``(x0s, _)`` batches; must support ``len()``.\n",
"        optimizer: Optimizer stepped through `scaler`.\n",
"        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.\n",
"        loss_fn: Callable applied as ``loss_fn(gt_noise, pred_noise)``.\n",
"        epoch: Epoch number shown in the progress-bar description.\n",
"        base_config: Supplies ``DEVICE`` for the sampled timesteps.\n",
"        training_config: Supplies ``TIMESTEPS`` and ``NUM_EPOCHS``.\n",
"\n",
"    Returns:\n",
"        float: Mean of the per-batch loss values over the epoch.\n",
"    \"\"\"\n",
"    model.train()\n",
"    epoch_loss = MeanMetric()\n",
"\n",
"    with tqdm(total=len(loader), dynamic_ncols=True) as progress:\n",
"        progress.set_description(f\"Train :: Epoch: {epoch}/{training_config.NUM_EPOCHS}\")\n",
"\n",
"        for clean_batch, _ in loader:\n",
"            progress.update(1)\n",
"\n",
"            # One random timestep per sample; the upper bound is exclusive.\n",
"            steps = torch.randint(\n",
"                low=1,\n",
"                high=training_config.TIMESTEPS,\n",
"                size=(clean_batch.shape[0],),\n",
"                device=base_config.DEVICE,\n",
"            )\n",
"            noisy_batch, target_noise = forward_diffusion(sd, clean_batch, steps)\n",
"\n",
"            # Forward pass and loss run under autocast (mixed precision).\n",
"            with amp.autocast():\n",
"                predicted = model(noisy_batch, steps)\n",
"                loss = loss_fn(target_noise, predicted)\n",
"\n",
"            optimizer.zero_grad(set_to_none=True)\n",
"            scaler.scale(loss).backward()\n",
"\n",
"            # Optional gradient clipping (disabled):\n",
"            # scaler.unscale_(optimizer)\n",
"            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
"            scaler.step(optimizer)\n",
"            scaler.update()\n",
"\n",
"            batch_loss = loss.detach().item()\n",
"            epoch_loss.update(batch_loss)\n",
"            progress.set_postfix_str(s=f\"Loss: {batch_loss:.4f}\")\n",
"\n",
"        # Still inside the `with` so the finished bar shows the epoch mean.\n",
"        mean_loss = epoch_loss.compute().item()\n",
"        progress.set_postfix_str(s=f\"Epoch Loss: {mean_loss:.4f}\")\n",
"\n",
"    return mean_loss"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 17,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:29.926465Z",
"start_time": "2023-02-23T07:36:29.906486Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:43.842403Z",
"iopub.status.busy": "2023-02-22T16:01:43.841296Z",
"iopub.status.idle": "2023-02-22T16:01:43.857743Z",
"shell.execute_reply": "2023-02-22T16:01:43.856757Z",
"shell.execute_reply.started": "2023-02-22T16:01:43.842351Z"
}
},
"outputs": [],
"source": [
"# Algorithm 2: Sampling\n",
" \n",
"@torch.inference_mode()\n",
"def reverse_diffusion(model, sd, timesteps=1000, img_shape=(3, 64, 64),\n",
"                      num_images=5, nrow=8, device=\"cpu\", **kwargs):\n",
"    \"\"\"Sample images with the reverse (denoising) process and save the result.\n",
"\n",
"    Starts from standard-normal noise of shape ``(num_images, *img_shape)`` and steps\n",
"    ``time_step`` from ``timesteps - 1`` down to 1, each time applying\n",
"\n",
"        x <- one_by_sqrt_alpha_t * (x - beta_t / sqrt_one_minus_alpha_cumulative_t * eps)\n",
"             + sqrt(beta_t) * z\n",
"\n",
"    with ``eps = model(x, ts)`` and ``z`` fresh Gaussian noise (zeros when\n",
"    ``time_step == 1``).\n",
"\n",
"    Args:\n",
"        model: Noise predictor called as ``model(x, ts)``; switched to eval mode.\n",
"        sd: Object exposing ``beta``, ``one_by_sqrt_alpha`` and\n",
"            ``sqrt_one_minus_alpha_cumulative``, read through ``get``.\n",
"        timesteps: Exclusive upper bound of the timestep range.\n",
"        img_shape: Per-image shape appended to the batch dimension.\n",
"        num_images: Number of images sampled in one batch.\n",
"        nrow: Images per row in the output grid.\n",
"        device: Device on which the sampling tensors are created.\n",
"        **kwargs:\n",
"            save_path (str, required): Output file. In image mode its last three\n",
"                characters, upper-cased, are passed to PIL as the format\n",
"                (e.g. \"x.png\" -> \"PNG\").\n",
"            generate_video (bool, default False): If true, a grid is captured at\n",
"                every step and handed to ``frames2vid``; otherwise only the final\n",
"                grid is saved as an image.\n",
"\n",
"    Returns:\n",
"        None. The result is written to ``save_path`` and shown via ``display``.\n",
"\n",
"    Raises:\n",
"        KeyError: If ``save_path`` is not supplied.\n",
"    \"\"\"\n",
"    # Resolve the options once, before sampling. Previously `save_path` was read as a\n",
"    # bare name (working only if a notebook global of that name existed), and a missing\n",
"    # key surfaced only after the whole sampling loop had run.\n",
"    generate_video = kwargs.get(\"generate_video\", False)\n",
"    save_path = kwargs['save_path']\n",
"\n",
"    x = torch.randn((num_images, *img_shape), device=device)\n",
"    model.eval()\n",
"\n",
"    if generate_video:\n",
"        outs = []\n",
"\n",
"    for time_step in tqdm(iterable=reversed(range(1, timesteps)),\n",
"                          total=timesteps-1, dynamic_ncols=False,\n",
"                          desc=\"Sampling :: \", position=0):\n",
"\n",
"        ts = torch.ones(num_images, dtype=torch.long, device=device) * time_step\n",
"        # No noise is added on the last step (time_step == 1).\n",
"        z = torch.randn_like(x) if time_step > 1 else torch.zeros_like(x)\n",
"\n",
"        predicted_noise = model(x, ts)\n",
"\n",
"        beta_t = get(sd.beta, ts)\n",
"        one_by_sqrt_alpha_t = get(sd.one_by_sqrt_alpha, ts)\n",
"        sqrt_one_minus_alpha_cumulative_t = get(sd.sqrt_one_minus_alpha_cumulative, ts)\n",
"\n",
"        x = (\n",
"            one_by_sqrt_alpha_t\n",
"            * (x - (beta_t / sqrt_one_minus_alpha_cumulative_t) * predicted_noise)\n",
"            + torch.sqrt(beta_t) * z\n",
"        )\n",
"\n",
"        if generate_video:\n",
"            x_inv = inverse_transform(x).type(torch.uint8)\n",
"            grid = torchvision.utils.make_grid(x_inv, nrow=nrow, pad_value=255.0).to(\"cpu\")\n",
"            # Channels-last array with the channel order reversed.\n",
"            # NOTE(review): presumably BGR for an OpenCV writer inside frames2vid - confirm.\n",
"            ndarr = torch.permute(grid, (1, 2, 0)).numpy()[:, :, ::-1]\n",
"            outs.append(ndarr)\n",
"\n",
"    if generate_video:  # Generate and save video of the entire reverse process.\n",
"        frames2vid(outs, save_path)\n",
"        # Undo the channel reversal before displaying the final frame.\n",
"        display(Image.fromarray(outs[-1][:, :, ::-1]))\n",
"        return None\n",
"\n",
"    # Display and save the image at the final timestep of the reverse process.\n",
"    x = inverse_transform(x).type(torch.uint8)\n",
"    grid = torchvision.utils.make_grid(x, nrow=nrow, pad_value=255.0).to(\"cpu\")\n",
"    pil_image = TF.functional.to_pil_image(grid)\n",
"    pil_image.save(save_path, format=save_path[-3:].upper())\n",
"    display(pil_image)\n",
"    return None"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 18,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:07.313353Z",
"start_time": "2023-02-23T07:36:07.307373Z"
},
"code_folding": [],
"execution": {
"iopub.execute_input": "2023-02-22T16:01:44.742690Z",
"iopub.status.busy": "2023-02-22T16:01:44.742000Z",
"iopub.status.idle": "2023-02-22T16:01:44.748704Z",
"shell.execute_reply": "2023-02-22T16:01:44.747497Z",
"shell.execute_reply.started": "2023-02-22T16:01:44.742652Z"
}
},
"outputs": [],
"source": [
"@dataclass\n",
"class ModelConfig:\n",
"    \"\"\"UNet hyper-parameters; each attribute is passed to the UNet constructor.\n",
"\n",
"    The attributes are annotated so that @dataclass registers them as fields\n",
"    (un-annotated class attributes are ignored by the decorator). Class-level\n",
"    access such as ``ModelConfig.BASE_CH`` keeps working, and instances can now\n",
"    override individual values, e.g. ``ModelConfig(DROPOUT_RATE=0.0)``.\n",
"    \"\"\"\n",
"\n",
"    # Base channel count; BASE_CH * BASE_CH_MULT gives 64, 128, 256, 512.\n",
"    BASE_CH: int = 64\n",
"    # Per-level multipliers of BASE_CH.\n",
"    # NOTE(review): the original trailing note \"32, 16, 8, 4\" presumably lists the\n",
"    # spatial resolution at each level - confirm against UNet.\n",
"    BASE_CH_MULT: tuple = (1, 2, 4, 8)\n",
"    # One flag per level, same length as BASE_CH_MULT.\n",
"    APPLY_ATTENTION: tuple = (False, False, True, False)\n",
"    DROPOUT_RATE: float = 0.1\n",
"    # NOTE(review): original note \"128\" presumably means BASE_CH * TIME_EMB_MULT - confirm.\n",
"    TIME_EMB_MULT: int = 2"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 19,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.588388Z",
"start_time": "2023-02-13T13:01:00.344403Z"
},
"code_folding": [
0,
13,
23
],
"hide_input": false
},
"outputs": [],
"source": [
"# Build the UNet; input and output channel counts are both IMG_SHAPE[0].\n",
"model = UNet(\n",
" input_channels = TrainingConfig.IMG_SHAPE[0],\n",
" output_channels = TrainingConfig.IMG_SHAPE[0],\n",
" base_channels = ModelConfig.BASE_CH,\n",
" base_channels_multiples = ModelConfig.BASE_CH_MULT,\n",
" apply_attention = ModelConfig.APPLY_ATTENTION,\n",
" dropout_rate = ModelConfig.DROPOUT_RATE,\n",
" time_multiple = ModelConfig.TIME_EMB_MULT,\n",
")\n",
"# Move the model to the configured device.\n",
"model.to(BaseConfig.DEVICE)\n",
"\n",
"# AdamW over all model parameters, learning rate taken from TrainingConfig.\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig.LR)\n",
"\n",
"# Data loader for the configured dataset.\n",
"dataloader = get_dataloader(\n",
" dataset_name = BaseConfig.DATASET,\n",
" batch_size = TrainingConfig.BATCH_SIZE,\n",
" device = BaseConfig.DEVICE,\n",
" pin_memory = True,\n",
" num_workers = TrainingConfig.NUM_WORKERS,\n",
")\n",
"\n",
"# Mean-squared-error criterion (presumably the `loss_fn` handed to train_one_epoch - confirm).\n",
"loss_fn = nn.MSELoss()\n",
"\n",
"# Diffusion helper `sd`, the object forward_diffusion and reverse_diffusion receive.\n",
"sd = SimpleDiffusion(\n",
" num_diffusion_timesteps = TrainingConfig.TIMESTEPS,\n",
" img_shape = TrainingConfig.IMG_SHAPE,\n",
" device = BaseConfig.DEVICE,\n",
")\n",
"\n",
"# Gradient scaler for mixed-precision training; train_one_epoch drives it via scale/step/update.\n",
"scaler = amp.GradScaler()"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 20,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:01:00.603387Z",
"start_time": "2023-02-13T13:01:00.590387Z"
}
},
2024-04-09 09:31:18 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Logging at: Logs_Checkpoints/Inference/version_0\n",
"Model Checkpoint at: Logs_Checkpoints/checkpoints/version_0\n"
]
}
],
2024-04-08 11:37:01 +02:00
"source": [
"# NOTE(review): the +1 presumably lets a 1-based range(1, total_epochs) cover NUM_EPOCHS epochs - confirm against the training loop.\n",
"total_epochs = TrainingConfig.NUM_EPOCHS + 1\n",
"# setup_log_directory returns the pair (log_dir, checkpoint_dir).\n",
"log_dir, checkpoint_dir = setup_log_directory(config=BaseConfig())\n",
"\n",
"# Output extension for sampling: \".mp4\" when generate_video is set, otherwise \".png\".\n",
"generate_video = False\n",
"ext = \".mp4\" if generate_video else \".png\""
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 21,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T12:45:46.057703Z",
"start_time": "2023-02-13T12:45:39.770695Z"
},
"_kg_hide-output": true,
"hide_input": false
},
2024-04-09 09:31:18 +02:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 1/30: 100%|██████████| 469/469 [01:06<00:00, 7.08it/s, Epoch Loss: 0.0557]\n",
"Train :: Epoch: 2/30: 100%|██████████| 469/469 [01:00<00:00, 7.73it/s, Epoch Loss: 0.0287]\n",
"Train :: Epoch: 3/30: 100%|██████████| 469/469 [01:01<00:00, 7.64it/s, Epoch Loss: 0.0241]\n",
"Train :: Epoch: 4/30: 100%|██████████| 469/469 [01:00<00:00, 7.73it/s, Epoch Loss: 0.0221]\n",
"Train :: Epoch: 5/30: 100%|██████████| 469/469 [01:01<00:00, 7.68it/s, Epoch Loss: 0.0209]\n",
"Sampling :: 100%|██████████| 999/999 [00:14<00:00, 67.20it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwCh8Ufij4y8OfEbVdK0rWPs9lB5PlxfZoX27oUY8shJ5JPWuP8A+F2fEL/oYP8AySt//jdL8bP+Su65/wBu/wD6IjpngH4fx69bza/rtyLDw5ZnM0zHBlI/hX/GgDUtvib8WbzSLnVrfUZ5NPtjia4Wwt9ifU7Kzv8AhdnxD/6GD/ySt/8A43XQD4taK80/hybw9EPBZHlR28TFZQM/6zcDkk9cVz3xK0HQbSPSNb8KQlNDv4SFy7EiVWIYHcSQelAHYeIvij4ysPhz4L1a21nZfal9u+1S/ZYT5nlzBU4KYGAccAZ71x//AAu34hf9DD/5JW//AMbo8W/8ki+HX/cS/wDR61zfhLw7N4q8S2mlRHYsh3TSdo4xyzflQB654M1/4reLbV9Rl8Uw6ZpSZH2y5srfax9ANgzXNeIfir8Q9B1680v/AISlLkW77RPHY24WQYByPkPrWJ8RPFKalfRaDpMzjw/pQ8i1jBx5hHV29STn+nWuI6cHIoA9v8O/FHxlf/DnxnqtzrG+9037D9kl+ywjy/MmKvwEwcgY5Bx2rkP+F2fEP/oYf/JK3/8AjdHhH/kkXxF/7hv/AKPauAoA7/8A4XZ8Q/8AoYP/ACSt/wD43R/wuz4hf9DB/wCSdv8A/G6raN4T0/TtC/4SPxebmGwk+WytISFlum6556Jxgn3rbu9PsdB+Gmoavf6VbQ3mvSpFptqyZa2hXkuCctz65zQBu/C/4o+MvEfxG0rStV1j7RZT+d5sX2WFN22F2HKoCOVB4Nch/wALs+IQ/wCZg/8AJK3/APjdHwT/AOSu6H/28f8AoiSuAALMABkk4FAHs/hPxX8Y/GiTyaPq8bRwEK8kttbou7rgHy+uKw9V+LPxJ0bV7rTbnxChmtpDG5jtLdlJHofL5Fdrrmr3Hwc+G+kaTpNuF1bUU824uZFyEcj5se46AV4jocWn3+vW8Os3E0VrO+2WeMbnTP8AFjvzzigDrf8AhdnxD/6GD/yTt/8A43XX/FH4o+MvDnxG1XStK1j7PYweT5cX2aF9u6FGPLISeST1ryLXtJk0LX77SpH8xrWdovMAxuAOM4966742f8ld1z/t3/8AREdAB/wuz4h/9DB/5JW//wAbrV0n4k/FbWrLUbyy1ndbafAZ7iVrO
3Cqo7D93yT6V5WASQAOScCvWPH9wPCPgDw74R0+R4pLy1F9qDKApm39FYj0II+gFAGN/wALs+IWP+Rg/wDJK3/+N12HiH4oeMrH4deDNVttY2X2o/bvtUv2aE+Z5cwVOCmBgHHAGe+a8Qr0DxZ/ySH4d/8AcS/9HrQAf8Lr+IX/AEMH/knb/wDxFH/C6/iF/wBDB/5JW/8A8RXUeBvhJ4e1fwTB4n1/VLqCFi7usbKqKinHJIJzkHvVS+1P4L2g+zW2g392U63AmkXd/wCPYx+FAGF/wuv4hf8AQwf+Sdv/APG66/w98UPGN/8ADnxnqtxrG++077D9ll+zQjy/MmKvwEwcgY5Bx2qP4geG/Ca/Cew8QaHosmmySXCrH5rNvZSGznJORx1rlPCX/JI/iJ/3Df8A0e1AB/wuv4hf9DB/5J2//wARR/wuv4hf9DB/5JW//wAbrM8DfD/VvHeoSQ2G2G2hGZrqQZRD2HuT6fWu51DQvhP4MaSx1Oe81vUUG2RYpCojfH+wRj8c0Acz/wALr+IX/Qwf+SVv/wDG66/4X/FDxl4i+I2laVqusfaLGfzvMi+ywpu2xOw5VARyAetVtb8H+CNW+Fd94u8O2t1YywMqoksrMCwZQwIYnPBrmfgp/wAld0P/ALb/APoiSgBv/C7PiH/0MP8A5JW//wAbpR8a/iIzADXySegFlb8/+Q68/r0/whBpngrwWfG2pW0V7qV07Q6VaygFVIyDIR7GgAb4n/FlLH7a99dra5x5p02Hbn6+XWf/AMLs+If/AEMP/klb/wDxuvR/BvjTxZfaTqvinxdfKnhtLdxFbGFE+0ORgBONxH48189MQXJHQmgD78ooooA+RfjNF5/xk1iLcF3vbLk9swRitj4ySyeHrLQfBNpIg0+ztFmkEYx5kpJyx/n+NYvxocx/GHWpFxlWt2H4QR10p+NOhanFD/wkHgq1vrmJVHnKUBOBjnKk49s0AeY+HfDGq+KNUgsdNtZJGlbBk2nYg7knpgV1HxOvtKtpNM8K6LIJrTRofLlnB4lmJJcjtjNT+J/i7qmpWzab4eto9B0hlINvahQzZ6ksAMfhivOMk+9AHf8Ai3/kkPw6/wC4l/6PWovhbpUmravq0I1GTToRpkjTXUfVE3Ju/MZFS+Lf+SQ/Dr/uJf8Ao9a5jwz4m1Hwpqw1HTmj8woY5I5V3JKh6qw7g0Ad1deOvB3hpYrLwt4UsNQaA/Pf6nHvMp7kLnI/Or3jCa08a/CaHxZDoVlp95aX/wBmf7GmwFNuST6jO3GfesRfHXg673S6n8PbJrj7wazuZIU3e65II9qxPE/jq98Q2MGmQWsGm6RbkmOytQQpPq3PJ96ANXwj/wAkj+Iv/cN/9HtXK+HrvTLHXbW61ezkvLKJtzwRtt3kDgE+mcZ9RXVeEf8AkkXxF/7hv/o9q4E+tAHtennwt8T/ABlBLLba0wiXzZ/MnUWtrEuPlA25CngcEda4r4n+MU8XeJFFpGI9NsFNvaKB/Dnk/jgflWvq2qWngz4b23hzTJkbV9XVbnUp4nz5cf8ADHkd/wD69eZ5oA774Kf8ld0P/t4/9ESVyGhadcav4g0/TrVlW4ubhI42c4AYngmuv+Cf/JXdD/7eP/RElcJb3M9ndRXNtI8U8Th45EOCrDkEGgD6s8ceArnxl4n0D7WVfRLONjdKrEO7dlH16ZqPw5pOnjxSlvpHgmzttEtdwGo3MIMjuBwUJySM9zmvFE+NnjxAAdXRgFwA1tGfxzt5q63x48ZNGUZ7TmExZEIHzf3+O/t09qAMP4rOJPibrTKE/wBeR8hyD7/Wrvxs/wCSu65/27/+iI64KWV55nlldnkkYs7Mckk9ya7342f8ld1z/t3/APREdAHC2sUk95BFFkyPIqrj1J4r6N8fW3w+aTTtH8UyyJr5tIoPtUDMTEQOC3OMZz2r590LU/7F16x1PylmFrMspjPRsHpXVfErxFofiPxVF4h0Sa78+
dUa4huI8LG6gAbT3HFAHM+JNBufDOv3ekXZVpbd8B16Op5Vh9QQfxrrPFn/ACSH4d/9xL/0etcfres3viDWbrVtQc
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAACLwElEQVR4Ae3dZbhl2VXG+wR3d7+4uxMpCU5wdw0Q3EKAdLq7GgmEJHiQECQ4wd1KcXd3d3e9v7X++4wze1d1p/Pc536r+WHWmGO+4x1j+lxr733q3v/3f/93r5vpZg/c7IEnpwee4skB38Te7IGbPbD1wFPVDU//9E//3//93//7v//r8HmKp3iKe9/73gQ5jWInklxSZKIqQzllMmH0ZLZP/dRPjfapnuqp/vM//zOryFf8//zP/yhCyllJ1TLHJgE8zdM8zX/913/RjzmlqsIL/5RP+ZTjRVVxqiLHGY94IGmGPxcrPpMtlP/7v8Bp2ILJ0dKTkSTniAsm5BqelZyGHn5jPBHyGCflxJOLmJ/u6Z6OoJY75Krk5GrlNFVxUUpJzvsISJLjUQtJpqzzaSQwOVqCXLfLM5QPSZ0QcjPb2eQI99LWZBjMzP/1X/9V8Wmf9mkLG4mYwYatFrGloQeWFOvGgqFRhSF+OYCi2cUK5j/+4z/CBNgoTmYUDDBMVgHyLlermWJIA7ZGAkwvx8ZdmEPoLKkY7L7uNMYMIlIVRSzwtVYtWdXKEBVNVvUypUSJgSCXuG7q0Bc6ZWHAaEDyaCJcSVhhKAC11pg8w7wUNiUrcikMW0UypyNsYe1FmjHJPCu5sGtL+iKHz1ZxIiy2iWTlHAbKgTFcvRQJDYzU8BfVVKnNLwA575CCHAwT+rolsNo2Iz1GE6Cqop1WxE+ZwEsBNK2Z0Jfo0VabOX0Tmp68teGkl8BwKuYufmCG6eNhNcjCY5I7uRapZaIDFXmpjRP8gIuzolwxjwlImLNal9B4KZJqNy9TgUjKOEGep1GmyQ2ZLZlQswmQqCmFPjGF3JztK2SKNJnUR40BDXMkGKQP/MAPfL3Xe73nfd7n/cd//Mdv/MZvfOITn1jAmRRAEy69nBVzQuTp4cFWq1ohYMq1gRMeJVmKDS15vLDSv02RGOqKCT7OcspneqZnwvAv//IvSODjoQFQBJCGnxxglGDCqAlkyMDJwWiiajFgUCvCijktYHhFtXmJSv9D0leUTzHweFElmKgwTJD0kIppmsf1Xo7InEoBFJmkAWjghFFtsPhpFJFLITNMI1fMBVheFOmHZDCRB2bF6RDmhflg1DJUhNHerDY56chruIwzI09AwIpSLvXO6okeQMJMnjxHM2yqalWAvEyIfL3US73UW7zFW7zne77nK73SK+G3bH7hF34BOAx+GPmY5BEAuNbKuStOsNkayWDh5WDFT59hVTHzImWCCpKtRO7WATyAGBTDnD9//v73v7/++ed//ueXeZmX+czP/Mzf/u3fZlgtQXjNrQm12sFARs4voTjlJnQyjwIID0BARfOgBz2I/CVf8iX4FUvYGHJaM2tULgCGP31FuSp4bPQ5ykROLwwCTAByAkNyxQRIykySs6p2tSKDpS+HzLsqLSIDlANgAxh8MUNKlFWlVFz1eQlWniYT/BLDeowg5eV0ASmzmV5QBGJGiSUi1BER7nvf+zoH/u7v/u5Hf/RHf/VXfxV1pABqyWMVJ8KGsKrBEPKVORnDq7zKq3zAB3zAW7/1Wz//8z//3//93//TP/2Ts/t5nud56m4mNQ8SXmIbD7mAFQtJGGNFUKuqVuymmzsCJRIAgjitsdVFsUGmhEdLbggVAdSyjUHVAx/4QEfl/e53v2d4hmeghPn2b//23/md3ynUwE1iJDOV4+EdXgKjkTDwWOSs2gJoYCglQiSEd3/3d7/99tt/8id/8su+7MtaYJQlSGwFX1T0okWoWOvUNlLAcdKrzQv8RFUAclVxlrMiBDs4PumiwPjzzhAsjMAq5otcAiCsYStS5qKqIZkqAtqKbAFioMTPVlGqaGo93/M9n+If//Ef/83f/A0rAK1mpS0V1VIW6umyyXF1kcqlfMuT8/RyL/dyH//xH//mb/7m//
Zv//ZRH/VRv/Irv9LwwOBpMhUx34ScqSqUijEHGwzlS7/0S9ssP+iDPsiV5jd/8ze/53u+5zd+4zc0zBL1iPzv//7vmYukYHJa/KpqJ8JcV4QEKx9ftUieJgFG0dsFhGTzSTFy3a1p4x2eXi5RckeA1xWWyju8wzu88Ru/sXMm/KVLl5ruvbcAxiaXANBO2OHLucbJC1hehARZMQ0kgR6/qrNnz370R3+0a63ee8ZnfEabzhDCcBpbLaqKrKpOkDd8qsiixSlVS4CUck1ZqhVqWXFBSViRikzqQLUte35DBq5YWwJHiwdA7XRC3ukpkwFKU3WiOPzLEe8KhGjNpf9nTy/0Qi/04i/+4q/1Wq+l9od/+Id//Md//Bd/8Rf/4i/+Qj9AciHV5In2sFd5KNz6Zn/MBSroclw0dTcZ0rR2wqDAdfHiRWcC/aTNyUmKQQlYQkIjndQf5lyaWmXCPeQhDxHuP/zDP3zzN3/za7/2awMPOUGQbY1HJKuXvSlPhbAWEZgAELIiSClp0JI9fphnXiq+6Iu+6Mu//Ms/y7M8i2tYVsCodqNtTkjwukLkKYe2UF/xFV/xu77ru8BMZW3xKult3/ZtjdPW+H2NsRpmVJGUA/C1O9k2+0mYMwFIqCpZjudVX/VVrU87i9vgh37ohzIZRwDwE3O2eaktNIFhJHKYwmMuFXy0xSMHfvZnf3ZdxyRDQpyRTFuigiEkc0GgyRe8hFPiS0rPdfrd7qkpVxJV9HnJKga5qj3wLRsl8+d4jud4szd7s+/4ju+wL2RYbkq72jziEY8wDSaGODEQDm2Jy+hWJ1ddMTM5KDqCOe2y7nnDJSpDS8g9yvTScWCmiO0t6swZijIw/Si1c5wOXu1LvMRLvPIrv7LG/NIv/dIHf/AHW/SFpKoYgBPSFDbl9EsamFHSkOElsqo0thPdKteJRp1r7pzRj3rUo+wF7/Ve7/Wt3/qtIoFhLmUbj3x25eKpW4F1gi1fKwhMrBmXWLMZPr8wCQGKp/jTjwtdx4prSFbpKRVzmhJA1XM/93M/+tGPfs3XfM2//uu//pRP+ZQv//Ivx1bKECxflBO8qswJ9LUCTEIup8mL0ScbawNXt0T7gi/4gp5CjftP/dRPaWmNim31jqqZAMCjKphcKErjazwCMIEM4GpT2DN58iIHkE9CW4/JsdFHLifrqLd7u7e74447tIjGjeAv//Ivtcu+xpC7+9znPmTugDEXKiGqzUsBiYNKqmEEVVEkyJ/t2Z7N6f/nf/7n6OSf/MmffPXqVT1okP7wD//wd3/3d12lvumbvskqAuZp59t6HGdeopVTSlMEllhpxod8yId4AHCaeRPQZg9WbXiGUxwSmtlvagunGcpHs3k9Sdyp4tH172EPe9jP/MzPuBlSit8F1xsIi9ZyokGVEUFaewwzDR45ucDe4A3e4Ed+5EfAjLGL5eMe9zgXAFVgEiqcVulzPudzFoOq+ANQKuYFLTkXKya5WlYEyUL9q7/6K4Yf/uEfbrDimdpcgyVg0L15Idi8JVVyVdoiQgySgbCD2IB/4Ad+wFgzMRd/4id+4vKeDJbJYOZR/tqv/ZoXOZrDCwYpwrxEO/kW8Z6CEW3wL/mSL/mu7/qu3pf+wR/8we/93u/Za376p3/693//971EtZd96qd+6uu+7usKCa20ezj0P9q88C6ppZHI8rXI3L7/R3/0R0bHS6ZP//RPdxFooF3ZPumTPgmP3dOsYJUv+XDm5bBsdhcbe27ylA2ZYDLpOB3EmSadO3fuuZ7ruTzvfvVXf7U7NC57qqf23/qt39LL42Mizln80cJIOktOU/rYj/3Yv/3bv/3lX/5lmwFNVayCEUrpM88FWVVeaKQGI7mcJpgiMH5Pgd/wDd9gKrjSfNzHfZxNiPKd3umd8Pz6r/+6uc4vpFzKaqZaxWr3KbfNtlrxFV/xFXoJCVot0nVgRYsn147rL/zCL7z11lstHoHFj0ECoKktTeX02aoNU1
6RXwNvLFzPGL7pm74pDY9hMqw4jhiS86JR5HIC29d5nde5ePGiW/6P/diPebx0YbZXStZGJmQ7ixPG6Dt5JNu2CeB
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 6/30: 100%|██████████| 469/469 [01:00<00:00, 7.71it/s, Epoch Loss: 0.0199]\n",
"Train :: Epoch: 7/30: 100%|██████████| 469/469 [01:00<00:00, 7.69it/s, Epoch Loss: 0.0195]\n",
"Train :: Epoch: 8/30: 100%|██████████| 469/469 [01:01<00:00, 7.62it/s, Epoch Loss: 0.0190]\n",
"Train :: Epoch: 9/30: 100%|██████████| 469/469 [01:02<00:00, 7.51it/s, Epoch Loss: 0.0187]\n",
"Train :: Epoch: 10/30: 100%|██████████| 469/469 [01:03<00:00, 7.44it/s, Epoch Loss: 0.0186]\n",
"Sampling :: 100%|██████████| 999/999 [00:14<00:00, 66.95it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxR8ZeHPiNqulaTrP2exg8ny4vssL7d0KMeWQk8knrXH/8Ls+If/Qw/wDklb//ABuj42f8le1z/t3/APREdcLa2017dxWtuheaVwiKOpJOBQB3X/C7PiH/ANDD/wCSVv8A/G6P+F2fEP8A6GH/AMkrf/43WV488PWHhbVbTR7aZpb2C1T+0G3ZUTn5iq+wBArN1zw/PoUOmPNJv+32i3S4XG0Enj36UAer+Ifij4ysfhx4L1a21ny77Uvt32uX7LCfM8uYKnBTAwDjgDPeuVtvjH8Sry4S3ttbkmmc4VI7GAkn6eXUXiz/AJJD8Ov+4l/6ULXGaZqt/o14LvTbuW1uApUSxMVYA9eaAPSdY+I3xe8P+WdWvbmzEgypl0+3Gf8AyH7Gsn/hdnxD/wChh/8AJK3/APjdL4S1K/1rRPFNprN1Jeactj9qZ7uRpDFMrAIykngncw968/7UAe3eHvij4yvvhx401a51jffab9h+yS/ZoR5fmTFX4CYOQMcg47Vx/wDwu34h/wDQw/8Aklb/APxujwn/AMkh+Iv/AHDf/R7V5/QB6K/xj+JUdvHcPrbrDKSEc2NuAxHXH7vmov8AhdvxD/6GH/ySt/8A43Vz4tWTaPYeEdFPIs9NyXAwGLkE1wkWgavPpc2pxabcvYwkCScRnauenNAHr/wt+KXjPxH8R9K0rVtZ+0WM/neZF9lhTdthdhyqAjlQetch/wALs+If/Qw/+SVv/wDG6T4Jf8le0L/t4/8ASeSuAoA9A/4XZ8Q/+hh/8krf/wCN0f8AC7PiH/0MP/klb/8AxusqwvfBEumWNvqmkapFeJuFxdWN0AJOeDsdWGQPTFdBd+D/AAZN4C1TxLpOpauRaskMaXaIoaVu3A5H0oAqf8Lt+If/AEMP/klb/wDxuuw+KPxR8ZeHfiNqulaTrH2eyg8ny4vs0L7d0KMeWQk8k968P716B8bP+Sva7/27/wDoiOgA/wCF2fEP/oYf/JK3/wDjdXbr4sfEyysbS7n8QIiXQLRIbS337QcbiPL4B7V5/pEC3WtWUDkBZJ0U5GeCRWx8QBGvj/W44FCwxXTxRqDnaqnaAPwHTtQBt/8AC7fiH/0MH/klb/8Axuuv8RfFLxlYfDnwXq1trGy91L7d9rl+y
wnzPLmCpwUwMA9gM968Qr0Dxd/ySH4df9xL/wBHrQAn/C7fiH/0MP8A5JW//wAbrQ034n/FnWJFj07ULm6Zs48rToDnH/bOvP8AQ72003WrW8vrFL62hfc9s5wsnHQ/jivfNOXxtrMVprd5JD4V8L26m5EGmrtklXjA2D72QB149qAOC1b4pfFfQrlbfVdSns5mG4JNYQKSP+/ddD4d+KXjK/8Ahx401a51nffab9h+yS/ZYR5fmTFX4CYOQMcg47V5/wDEjxnP418UG8eJ4be3jEFvG+d20Encw/vHPOPatDwl/wAkh+Iv/cM/9HtQAf8AC7fiH/0MP/klb/8Axuj/AIXb8Q/+hh/8krf/AON1z3gnSLHXfGOm6dqdyttZSyEzSM20bQC2M9s4x+NdpqniT4ZaPqDroPhBr7YSBLd3EhQsOmFJORn86AMz/hdvxD/6GH/ySt//AI3XYfC34peM/EfxH0nSdW1n7RYz+d5kX2WFN22F2HKoCOQDwa8t8Ua22t3MEsmiWOlyIpG2zgMSup6Er7eveui+CX/JXtC/7eP/AERJQAf8Lt+If/Qw/wDklb//ABuj/hdvxD/6GH/ySt//AI3Xn9WLG0e/v7e0jIDzyLGpPYk4oA9FPxW+Ki6QurHVZhp7SeUtx9gg2F/TPl9apf8AC7fiH/0MP/klb/8Axutz4t+IINN03Tvh9pTKbTS40W6k2AGSQDvjj3PvXkdAH3/RRRQB8n/FTR73Xvjnq2m6fGJLmZrdVBYKP9RHySegrS8PaX4W+F+oprHiHVodR1u3yYdPsWDrG38LFvX/ABrA+NTsnxg1xkYqw+z4IOCP9Hjrz6gD0/XPEPw18UTz6neaTrGm6hKzPItrOHWVj3JYED8BWf8AFbURearotvDE0VlbaRbLbI4+YIUDfMe5yetcBWz4k8QS+Iru1nkhWIW1nDaoqnORGgXP44z+NAHT+Lf+SQ/Dr/uJf+j1rgK7/wAW/wDJIfh1/wBxL/0etcEgUyKHYqhIy2M4H0oA7n4ZWx1efWvDzyeTFqtkYlmYfKsqsrJk++CMe9cVd2z2d7Pay48yGRo3x6g4P8q72x8W6Z4P8MXunaFfNqd3eyRSLLPYLEtsUydwDbizc4Bzgc8c1wFxPJdXEtxM26WVzI7Yxlick0Ad34T/AOSQ/EX/ALhv/o9q8/U7WDYBwc4NegeE/wDkkPxF/wC4b/6UNXn1AH1FN4RtvGms+G/GCxm80+HTS8lq5DLI6r8iAdOSTn6VT1n4g6l4c8N6lLr1npllNIvkaXokYEhA/vyjpgenA9q8Q0j4g+JND8O3OhWN+0dhOD8gGGjJIyVYcjOPXua5qSWSaRpJXZ3Y5ZmOST7mgD0D4LOZPjHoshABY3BIUYHMEnQV59Xf/BL/AJK9oX/bx/6TyVwFAFrTrCfVdStrC2XdPcSrEg9ycCu6+JOo2mmW2neCNInEtlpC5uZU6T3J++34HIxTfhRCYr/Wtahjaa+0vT3mtIlwd0jfKDz6Zz+FcDNLLPcSTTOzyuxZ2Y5LE9SaAI69A+Nn/JXtd/7d/wD0njrz+vQPjZ/yV3Xf+3f/ANJ46AOV8MPFH4p0p5ziIXUZbnHGRXf+L/h7JceLvEM0utWH25pp72Oxgy7tFncSxHCHB715WCVIIOCDwa6688f3d1p1zBBpljaXd3CkF3fwh/OnRQFwSzEDIAzgDNAGDrEekJPEdGmupIWjBcXKgMr9wCOCK67xd/ySL4df9xL/ANHrXAV3/i7/AJJF8Ov+4l/6PWgDF13wTfaD4X0bXbiaJodUDFI16pjGM+uaz4/E+vwpbpFreootuu2FVuXAQegGeBXbeFPHuhTeGYvCnjfTXvNLgYvbXMJIlt/YYwSM579+c4qQeLvhvoc4OjeDJL+RB8txqFwSCe2U6UAP16zl8Y/DEeL9Qt1ttX0+UW81wU2fboz9046bhnrjmszwl/ySH4i/9wz/ANKGrL8YfEPW/GRih
u2itrCH/VWVquyJfw7/AI9O1anhL/kkPxF/7hv/AKPagDz+iiuw8GeDrXXLe61jWdTj0/RLFh9okyDLIT/Ci+p9aA
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAABqt0lEQVR4Ae3ddbRt2VU1+oe7u7vzcIdA3SqCQ3CHQEKwQPDgJFUJQZLgwTW4BHdJqm4hwQnu7u4u73dO39XvqLnW3vfcSvi+1l678495xhyzD5kue611nuh//ud//p/r4XoNXK+Ba6mBJ74W8HXs9Rq4XgPnNWC1WRacJ3qiJ3riJz4bTojQkggchJCsc+kzTogwAws9s2olTPGTPMmThI4JcYnwAcKkDQddIuKSME/6pE+a3Cd7sierFVm7ISaWrGhemE3GmSYRsRKj8TBxOI0RAjzHCpDEXCwmV5ayJEu5YoUsulkVPNd9RQ9+qgsy2gKILRwE5QIAZMDbGgMDeEKFenvVdqnFx8eBaaWmL6Kw4LqxEAGI1duh9YtgQIYg47//+78lEXKTDAy/TMjkhghebh0tHQBwiaotBxFBcZjEBXR8RaTVA5CFOIecYSjkm/h0gA+sVoqPkiZLHFNLVTAVhIxL4cQ3mP/6r/8KuKoAFEpxEPEk8X/+539GZ5WTrVpMsAAwoz8DAB0YhUKSkJiSEaGKfklMoZjkVu1MXhNdQyWYuKCGuyBSza2TchA1raSTv0sXvJtbbbPGrswuDPC+KhBp13JStrRTdWGGH//QwYuFxelwYPDRKXDEz9B3hIVfxzR8ZBEwsRUhtHCs2Lt8gvjXKlVVEZzi6LgKc0dR7tRvALidOhHrxGAIcUoUf1L8GEILEcFBSGoCgRSFiMaxW7BkZcGiMHiyzQp/xrJm8oI0nUGWkORbxSe9ZE2R0/j4NlXN8i5qm4SfIjWxEMWUCKDJ+nmooIyQegAnaNdw0KluWjIj4mTuBygGE4BqZUMnhBPz+DEUbRGEFwJDYP7Hf/yHpJ4kJmIvESQagFriZzJ39PvkJo6h3ZhU+DQUXD3Rtiu4y6SBP7LiFSLlLRPBojgOi4NM2RPjCJGNwvhTZ1QCGkas6kJgBixJLXFJAf9VX/VV3/M93/N5n/d5MSMYQKoUMwsO4kSgauZGA06JmXua5sMi1aSsrexkTrrIMGdWFQaTrNQqzkRWyTGisovUkiR+ttkVOhiKaN1xCzOA9jwtJMgSogFGQIcTv6NkloFU8GLaJOVW9v89D1r9n/7pn37yJ3/yt3/7t//6r//63//93+mcSiRjC5MJ/UlcT6p/asZsicg+/dM//fM93/P98R//8d/8zd9EcDofunpK1C4OmkLIqI0nsRKYWJArGPnpuDhEhOgMIT7HnoEXn8HkqihZEanglEUTfMZnfMZ73ete//zP/3z77bdHsIMk+jHjgDiAaFviYMqMoeKf8imfkj/P/MzP/FIv9VLP8AzPwK4G+qM/+qNf+7Vf+7u/+7tKlVjEk1xMFDyJuhqR6gmmGha+XFmprjYK5oTNSk4nfM7nfM4bb7yR/z/0Qz/0r//6r1EyRWK08WHY1MUQBKiOMTTmVFR1IQIOLE4AZ3hIlo4GsSJFUG6sYL7QC73QK7zCK9z73vd+rdd6rWd6pmfSm7/927/9N37jNz7zMz/zH//xH+EzwMi2UkgJcRixdKwFBhCjCGPmHd/xHV/0RV/0Uz7lUya/9KxuTCFWqiHMxMBy5/gPOK6iBWNGYYlHAw6p0KkoHMkw0dWPFlJjxkCUZJqo/uQ+7dM+7Zu92ZsZn9/4jd/4e7/3e/W5BBg6Js78OF++YogJydDi0gtfvT3Xcz3XpUuXFOflXu7l7n73uxs8T/3UT/0v//IvP/MzP/P5n//53/Zt37arcMusiRrdEsEsyCYRcU+JUrRqKKb8cmBaA8Gn4d7u7d5OZ/j1X/91I+eHf/iHq2oh0p/PmDQKPAgCIVAtqQ3EkuJwiCUZcJjiEvgwkkGe6ToXj5XkzlhuZD/u4z7uD//wD60tP/7jP/6FX/iFn/Zpn/aDP/iDpN7ojd5oKkcTyXaFHnTM4Qu1Imsb4gn+u77ru1
rHvv/7v39iiM/kCTpW4nkcKJ1kZSW5KokIs5Vz7u9hexYaJiGe1ApmAbRplNQtZrTRbJ0xZh7xiEdYCmKIFGQEozCxXER1tsaaKysaZkz/DTfc8IEf+IGaRl/8q7/6K1Py3//939sURIPe9iEf8iFVUh9C1MrUeYKO88ryHM/xHNaB53me53n2Z3/2p3iKp9iKTG9rZboxARGfnNDm0Mc97nHE//RP/1QpTliBP1jJn1blVDrl60qIp3qqpwqh5QILQVxIc0ZncmNFFmY6fWhIAeYFXuAFeG+pSRIH8Vu/9Vs/+qM/+rqv+7pRhUMqWaWjJ4BDkWJyLwbDfuhDHwr5Ld/yLVtI9Je/m5w1Bkln1IqDzzamsiHOUWdRk7GiIMLCV0WxkroKhqAQPaHFUXLTTTd993d/t1NN+lZgkY1mGiISh6NQ1lVrDP4lXuIlvvzLv9zAMGCMFlPbR37kR77aq73aG7/xG3/VV33VX/zFX1DyS7/0S3p2nNnGtcLiNjcc3lq4OKb2XvIlX/JN3uRNFEcb/cqv/Mrf/u3f2ju95mu+5jHZ8GtlwlpFYS7JMH/xF38xsj/90z/9Du/wDruYVuDBSo3JoCWVG9BiKVmYT/d0T/cRH/EROXdWJMbSVGE2RsQKTFo9bZmYLUF7qy/itYIw9FWZGe6FX/iFKRECloUgnjgi6JYFEvNc4kpUjmFjE/hJn/RJV/KuhWpZopDPpNH8iZp73vOeVs7v/M7vfJ3XeR11ZfvUrAAkbUTxSaUI4aND0BkrkAJm9CcORywpyNWDnS7e6q3eqlmYaIFOmBLnEmdGmRC3xqIn1hOX85Vf+ZVZVVSaJfq93/u9LQIRf5d3eRe7A0rMblN2oWOlpVtyJT/xEz/xZ3/2Z40Qmz1BWf7sz/7sL//yL7W+qfPWW2+V+3mf93kRTBGqRNHQmNuyJKvIEi0ajjWZ3ch+1md91syK2ilFYZCHs41ayCYbF64bblrMMdEV2inwkz/5k9/iLd7iR37kR3QOYLpgIpibgzDDEU/vkxQLFAoIeII1hKOK+fDpn/7pZh0zHBiMACOuVOymMPH5HHUWRW2T5VjWLPpVOAHXRMcoEXajrfGjH/1o25g3fdM3dQAAMInCOG+Yp//hH/5B53MbobU0kvNbSh1ZtIKLhTiDqNoQrTdJtABpBD73cz+3ZeHJn/zJ54m2stEjGaJxrERV6MRpU7TpzFn/8uXLZuLv+q7vUgpF+7d/+zdZ7/Zu73a/+91PfX71V3/1gx/84Cm+S88GqkWEJn77t397R81IudL4uZ/7OQdaR6lv+qZvMlCNImeqp3mapwGoYE0oCzr1EGbpFLPJrSzOLbfc4oRG8Iu+6Iv0t4IXVTHR3MOwSXev3tQakBCmOP49y7M8i/MTLWagqgbTD2wVbrzxRscSe8TUESlBbpBVK4kfpjgAMYBAFp2RoxPoZKY3s46sSBUfDZLCVFjNW+Jud7vbK7/yK5vMfvmXf3mbe3EOJ3nILq9InbtwONNrZt3IpuVP/uRPNPaLv/iLv+IrvqLp5jVe4zXUs3p71md9VgeDLKHcFtQtPZSkkhPHGbn46DADw8FHc8PC9bIv+7J6m51M7uuSG3EASHT0iDsyNVkw8FTJAg6nDhghj3zkI531zf2Kk1zxR33UR7m/0ddtDq0Vv/qrv9qsE0R9SIkg+XCf+9zH2EObU2677TbT8c///M+75OSMjmSWUWlmnHhbwRNWZKWWEKmKSFW2xCu90itZMDXHb/7mb5rsfud3fidq4ye6RPiND8OGGTUVG6GXEib5Ii/yIurLNPAlX/IlqcTqNac6vjvA6R9KG5MxUy+ZyLJGSmg/CIxdRJiaVmUx9IIv+IKKlBlUVpo8cZREObpWom0bB+OXDaUwdxo5W8wJzmIiPjDKK3EqjXhK8X3f9318djOLo0J0C2VxY2u0vM3bvI21Wj9wLRFzxGnTSxYl0ZZcVuTCxGLwEc/+Npd1VhtGYVI5rb
EgSanVKAFAhI+GDD3jFPn3f//3w3QxYGWwOTeRvfVbv7Wk1fJhD3uYTU6lUnwOCGWWCDNxlHPegdaQeNSjHvW93/u
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 11/30: 100%|██████████| 469/469 [01:33<00:00, 4.99it/s, Epoch Loss: 0.0182]\n",
"Train :: Epoch: 12/30: 100%|██████████| 469/469 [01:51<00:00, 4.22it/s, Epoch Loss: 0.0181]\n",
"Train :: Epoch: 13/30: 100%|██████████| 469/469 [01:50<00:00, 4.23it/s, Epoch Loss: 0.0180]\n",
"Train :: Epoch: 14/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0178]\n",
"Train :: Epoch: 15/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0178]\n",
"Sampling :: 100%|██████████| 999/999 [00:27<00:00, 36.96it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPquk6TrP2eyg8ny4vssL7d0KMeWQk8knrXH/8Lt+If/Qw/wDklb//ABuj42/8le13/t3/APSeOsn4d+FD4y8ZWmlsQLfmW4OcHy16496ANtfjH8THj8xdZlKf3hp8GPz8umt8aPiQgy+uuo97GAf+067nxV8aYvDGpT+G/DOi2EmmWH+jjz1JUsv3sAEcZ4yeuM964/xd8XP+Er8PSaU/h3T7YyBSZo1+ZXHUj0oA6TxF8UvGVj8OPBerW2s7L7Uvt32uX7LCfM8uYKnBTAwDjgDPeuP/AOF2/EP/AKGH/wAkrf8A+N0eLf8AkkPw6/7if/pQtc34Y8Jax4v1FrLR7bzXRd0jsdqRj1Y9qAOk/wCF2/EP/oYf/JK3/wDjdH/C7fiH/wBDD/5JW/8A8brQl+GGieGbYXHjPxTDbuVyLLTwJZm5xwTx+lbY0D4VaF4LXW7601i/+1OUtEunMEkuO6hSBtH94560AS+Hfil4zvvhx401a51nffab9h+yS/ZYR5fmTFX4CYOQMcg47Vx//C7fiH/0MP8A5JW//wAbp/hlom+FPxKaFCkRfTiiE5Kr9obAz3rzygD0D/hdvxD/AOhh/wDJK3/+N1p3PxK+Llpo8OrXF9cxafP/AKq5bT4AjfQ+XXLfD/wXc+OfE8WmRP5UCDzbmbGdkY9PcnAH1r1z4/eJ4tP0ay8H2iod6rLPx9xV+6B6ZPP4UAZPwt+KXjLxH8R9K0nVtZ+0WM/neZF9lhTdthdhyqAjkA8GuP8A+F2/EP8A6GH/AMkrf/43R8Ev+SvaF/28f+k8lcChCupYZAOSPWgD2W38WfGi80y3v7XUDMk43JElrbeYF7MVKZwexrAvfi98TtOuWtr3WJbeZeCklhApH/kOuU8V+I38S+KbvWVja3Er5ii3kmJRwqg+w9K6G6upPE3wslvtSlea/wBHulhinkcs8kUh+6SeuCfyoAd/wu34h/8AQw/+SVv/APG67D4pfFLxl4c+I+raTpOs/Z7GDyfLi+ywvt3Qox5ZCTySeTXh9egfG3/kr2u/9u//AKIjoAVfjV8Rnzt18tjk4sbc/wDtOk/4Xb8Q/wDoYf8AySt//jdaek/FXRPC+jxWvh/wbZrctEFuri8cymVsDPuAT
njOKvjxjpl14L1HUfEGj+Ho5NQiaHT7LT7JEm3AkGVm5KqCMDntQBzv/C7fiH/0MP8A5JW//wAbrsPEXxS8Z2Pw48F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM968Pr0Dxb/ySH4df9xP/ANKFoAP+F2/ET/oYf/JK3/8AjdSt8Y/iXGm99alVB/EdPgA/9F1ynhRdR/4SSyl0vSk1O6ikDJayQ+Yjn0YV9M2upeJbrSpZvHGhaDp2geVmeNlaR2A6LtzgH060AeGD42fEQgkeIDgdf9Ct/wD43XYeHfil4yv/AIceNNWudZ332m/Yfskv2WEeX5kxV+AmDkDHIOO1ct408a+GZ9Jn0Dwj4disbOSYPLdyrmWTHQDPKj2z2qDwl/ySH4i/9wz/ANKGoAP+F2/EP/oYf/JK3/8AjdSD4zfEkx+YNbkKf3vsEGPz8umfCrwKni3Vrm8v4ZZNK02MzTInBmbqsYPvg16pp8fj3V9TW3sP+Ec07RopFB0gLFIRGGHDrgndtyc+ooA8q/4Xb8Q/+hh/8krf/wCN12Hwt+KXjLxH8R9J0nVtZ+0WM/neZF9lhTdthdhyqAjkA8GvPfidY2WnfETV7WwAEKy8qIwiqxHIAHGM+laXwS/5K9oX/bx/6TyUAH/C7fiH/wBDD/5JW/8A8bpV+NXxFdgqa+WY9ALG3JP/AJDrz6vUfhz4judI0h/7A8DRarq8T5kv5FaTbk/KAP4fwIoApyfGT4lxLuk1qRB6tYQD/wBp1F/wu34h/wDQw/8Aklb/APxuvW5Z/EOs+DtWu/ibZWFlo3kM0dukZW4Eg+6V54PoK+ZTjJx0oA+/qKKKAPkD42An4v64B1P2f/0RHWT4V1zUvh54ntdWk09tzRMPJnUr5kbZBx+IPPtW18ZpPJ+M2sSlQwR7ZsHviCOvSdS1n4TeP7W01vXrs219FGsTwNM6MAOdoUHG0EnBGKAMHRbL4dfFS+mtIdLu9D16ZWcGGQvGxHU88de2B9a8e1rTJNF1y+0uWRZHtJ3hZ16MVJGR+Ver6l8WPD3hrT59N+Hugx2bOpQ6hMuZMeozkn8T+FeOySPNK8srl5HYszMckk9SaAO98W/8kh+HX/cT/wDSha6Pw7Y3t98JYdL8IXtjHqF3cNJqzNciKVVHCJz2PJ//AF1zni3/AJJD8Ov+4n/6ULXGaQuntq9r/aryJYCQGcxLubaOcAe/T8aAPVfDHwM1aTU7a88SXFnBpud5VbgM85HRAegz61yfxJtvEs3iSe91vS5rKFQIrWMj93HCvyoEPpx/OqnjbxrceKr6GKCP7JpFkvl2NmnAjUcZPqx7mufutU1C+ijiu765uI4hiNJpmcIPYE8UAdr4S/5JD8Rf+4Z/6UNXn9egeEv+SQ/EX/uGf+lDVw9i0CahbNdAm3EqmUAZymRnj6ZoA+mPgtosXhX4cXPiC/hSN7pGumcfe8hASPp0PFfPXizxFceKvE15rF0ADO/yqP4UHQflXqPxL+MGna14Xg0LwsbmCJxtuXaPy/3YGNgAPQ9/p714nQB6J8GvJHxm0b7OXMObjYX64+zyda87r0D4Jf8AJXtC/wC3j/0nkrz+gDv7jwNo3hzRrO88V6xPBe3sImh060hDSKp6F2PAyPatW50CHX/Clno/g3X7G4s4pDLPZ3O22uZZum9s/eGOBzwKzV+I2natp1nbeLfDFvrM9pGIo7z7RJDMUHQMyn5vxrkNd1ca1qTXKWVrZQj5Yre2jCqi9h6k+5yaAG6voOqaDOINTs5LdznG7BBx6EcV1/xt/wCSva7/ANu//pPHXA+Y/leVvby87tueM+uK7742/wDJX9c/7d//AERHQBwAIDAkZGeR612w13wAkcJHgy6eUIolD6m+1j3IxyM/WpIfC/gjUNMhubfxqbKfyh51te2hLeZjnaVIGPSrUWifDrQrZbvU/EFzrsx5S0sI/KU+zsckfhQBHrmg+FNS8HT+KtD+1aQEu
BbrYXT+aJW2gny268Z5zmk8W/8AJIfh1/3E/wD0oWua8ReJbjXpI4lhis9NtiRaWMAxHAp6+7Mccsck10vi3/kkPw
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAABdRklEQVR4Ae3dd7wmS1E//h+I4WfOOWfFnLO75oxixoCigqiIGREMXBQUs2LCgDlgzhm8FwPmnHPOOSvq9737eW5tbffMPHOePXfhddnnjzk13RW6q6urqnt65tzh//7v//6/G78bGrihgbNo4I5nQb6Be0MDNzRwWQOijd8d7nCHQR9zyYBwpttIORPJCchPzFJKnwGG27mz17Mv1Zi5GRsld7zjHXcSXs++bDT4XKrSl0O0cYMpRRTrlNTtWYGdCj0r2zPhn0sbzoWJZpc+Awy3pJyXoKMqmgVVY47SdoT//d//7YQD225Lneo0eGB+GpM1qhOYH+bJkz3Zk2FKEWuslQ/ch9uBsCt0qLput9fYhnRwm8m2EnpPmVEsKVdVoS0pXdB+tl3ETrgLWiNZa8BaOT4D2+F2UdAGtwEft/3IA+3JhMVn5nCn1G13EhmEAWe4LRm3G+BoB6OW9DeaLZK3eIu3eIEXeIFf/MVf/PEf//EglEsCdH0WCbSUd7bbytyPOfA5Sthb1WnXyjtOmK9hlugCOu0GvMZwgyRVCLdlHeU8IxymTQ0qSbOMmexoW68bwtzak0W/8Au/8Ku/+qs/xVM8xR/90R/99E//9L/92791tRxlW1p61Vd91fd93/d9iZd4id/5nd/5mZ/5mX/4h394/OMf/5u/+Zu/9Vu/9Sd/8iczH11QiDwccp3R5pL9mAPtyYQDn8XbbeZVW8Aik5MLBfMatdhGWUgBO5lv4F8KI7jA6Lw6QW9HcJ7+6Z/+zne+8/9/+feP//iPf/u3fxsPyjj+53/+5053uhOD+/u//3vX4rkopWqvEajWXouUp3u6p7v//e//wR/8wU/zNE/z67/+6x//8R//7d/+7f/1X/9VzKuRXYrkVpdTFcwXeqEX+tRP/VTwH/7hHz7jMz7jS73US4Gf6Zme6a/+6q9+4id+4od/+IdNJBorbgMQbSPJ2AMGhH6rqc/7vM+rPRj+y7/8y3/+53/22j1w7wt84lIy0D7t0z7t8z//8z/DMzzDkz/5k9PJUz7lU5L1sz/7s//93//dMZ/5mZ/Z7T/90z+xhF5eUtb4B5mWnv3Zn91A+PE1v//7v68ckNptWjglJfjbVwP3fM/3fM/5nM+pL3T+z//8z9ylrvGYRmqg7aIP+vHnAF3GhTHQDLfP8izP8oEf+IE6E0LAD/3QD33Hd3zHD/zAD3zVV33V13zN13zP93zPwx/+8Dd4gzfQoKINct3eRsC1SHnQgx70x3/8x+HAZHWENhfbuSGFVVHFX/zFX7z927990TKF13iN1/jar/3aX/mVXyHiIz7iI576qZ+6agHRecSV0A0pSAhixO/3fu9H+Y997GMf8pCHvOVbvuXrv/7rv8iLvMhLvuRLvtqrvdprvdZrgbuURXhbShpjzrzjO77j933f95mZf/mXf/lzP/dzqNj0G73RG3GgvKT2m8AC9f3ud7/73Oc++ludSte2paRhL/iCL0gzN998M4cLnw4/8iM/8mM+5mMuXLjAoBcbPxTukYLkqZ7qqZ7ruZ7rzd/8zb/sy77sb/7mb0LFxXN/f/Znf0Zo2K5NhIOUbWEhLhb8wT3ucQ/uhJvxY17IOxxuro95zGPYSnVsW0qh7QSYndGCXEYWwkFKNXubLSYs48d+7MeQp0e8zjd90zfR7yJhSRmUAxkfo/7Qhz40pqMkey2uDP3lX/7lLXUe+chHCkEDZ6x6X8AlZcB0+xzP8Rxv/MZv/CM/8iNpbTDn6zd+4zfypjN5L5mlDEqj6s/7vM8zxDP/f/3Xf/22b/u2u971rgaafUtB4XzJl3yJXLeLAHcpvZuF9nqv93qyYmhsV/7y53/+56ZoqOjznd7pnQpzBophlzKjpUR33vu93/snf/InxcPgk9jV+Ku/+qvv8i7vskauPFSHtc0aHqSguj7bsz
2bDnzap31a7Omv//qvOTyeQLz+j//4Dylbd6K6rTBsh5FYk3W0nIJe7uVe7u3e7u14U+sHrkJI/dzP/dzv/M7vnGkJTePnqqGE1vQi2k9T+QXOtdo/4NdtV04KmanEzAzxS4lRwZMIgB0C3vRt3uZtuJ7P/uzP7uucDEbaANmvpHQAwuu8zut8xVd8hf2G0ir9a780QwT4u7/7O6mgzI01J+yIfp3DUXhQGg/ye7/3e/KW537u5x5oDfdd7nKXixcvsuzneZ7niYNgkQxjwOy3c9ckJmam1grF3/qt30rzv/RLv4SbgPn+7//+4hhnJG3TwcUxnRl2cQNM7e/5nu/Zw5fe0ZsRp0a9EDPl2DxCN6pZ7pFp06W+zMu8jJmaOaOHn/zJn0xfBBh+3uilX/ql3/qt39rkYRnSM52stc0wEp3nNswK3/RN39TAsBLjhyEpxo8UapUuWnbz372HxXC/UD0Smiul4e34adGmWC0CsyqhEaqRsqYyHYYezLSHM9YFps84+rRBC+2oBeAm4Mtn4NP8N3zDN/COEqff+I3fYNlUxIla5zz4wQ9masy6Zi/8/b/eNQyl3BJvMYTaNYA3+fd//3fWfPe7311mKE/zw/zXfu3XvvRLv9QsNfQbsjpzKYPE8mEPe5g5Y6pYT0r10wsc+EQLKlnfm73Zm4lFv/zLv7x/THsDItGVNijNwtUQS5j5F3rjfHXHqGm2dODzP//zOeUP+IAPYMbmUvgsyFWUUny7sAGmrwc+8IFB5jhNWTZNjzUw4Bd7sRdjx8ZV1bM+67NWFVYhBGxLgWA+GCFZOwuT6VlYM6aQV2DNLeu8733vW+bepYDXfnMDGJlMWqfCliHKp9fIlVdfZhyWilYSlYx/loXkEz7hE4Sdj/7ojwbT6swkJWtSeA29fud3fmcGxzsy2R7k0b7u674uazBvb7rpJmOxxn9byjbVK73SK1kY/MEf/EEayf7sHOp45s9Mu9YXMcQ0o3m7C301KBrIKcx8UtRK/4Tome1Qsialo/E4r/AKr8BzWd4MSTir+9iP/ViRBx8xc+5LRjNSrkQb913AAOvGu77ruyrkDCQJ3/zN32zidhxT9rd/+7d7SWDCOucOd2QWwAcbD4mfDRnz/kVf9EVZFXGf9VmfZRlKfYKY3t7tbncT90Tbr/u6r2N/R1OpLgWcBqRVDM4iQSZgkgdNbvn93//99bBloN2+xZPTsuQoNO2f+2vdycXwr9COhpdiVQByKQ0PUiUdsEj44i/+YpOWC7dDIx3otWvwMEaLaFQkQ36VV3kVA2Qf1ZKacdO/iGdZSGNrTRq4lSyNNNNkaEoMpZgp5Xvt135to6/qlV/5la0IINDh7/7u71rtDHxOuzXVZ0JpCzPgiSzSePw//dM/tRJhCYWZNvehvDJtCmkGxI23fdu3NUdVcWNMdpgzM0mVdGFVOAByMLZrKZb0I7VGwtT/hV/4hUc84hEJlzR473vfmxMSyqWIQsTAZ/9tWmXhbvcmc4YF4+/pSvIErGqAF9mKimYyErW45TdgLs4Ke48SA+5gQN5/2w10aCT94PwZn/EZX/3VX30uY4S/8GWevPiLv7jshWGly2avrOmLvuiLrEZ6y4f29KrAFBWAo8zShRo5EfPE6BsR5dmD5RyDKV/iN4vVURGFuQgIMixZ+mcdaCEqNSDI/HTVNaPDAGTpnbbaXIW7pg3vGALe9Kd+6qekMUW/BmjBotHM+Jjb0b7Xve4lLJrosjJSDLkkTRomDQ0Jl/Ae7/EerFytByxnmjOLitZC6hOy2Rn35tb4GR5ZbyQOyhp6BBlazZZFEb2zhcCfsftE5uJZtZ1kDzw0kjflUASBPbRHcVjzK77iK0r0BQEthK/X8S8igAnDy9goyniF29CeNRFsl6X6hacNHt79R3/0R+lfQvGar/maSTcIxRyyEMfKzyQiyINiE+I8nTNJrM9lvN1T887mjBVaaLeuNfAlZhFbogyTj7E3bzm4iD
MURtEpHKQUpjhmWWmSQPAo0CpTN/StEAqw+2TAfv7nf95TqiocgDUpA1puPZfkkpHYz2DHADm0SL2I3AsHKb2b0Ib
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 16/30: 100%|██████████| 469/469 [01:49<00:00, 4.27it/s, Epoch Loss: 0.0175]\n",
"Train :: Epoch: 17/30: 100%|██████████| 469/469 [01:50<00:00, 4.24it/s, Epoch Loss: 0.0177]\n",
"Train :: Epoch: 18/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0172]\n",
"Train :: Epoch: 19/30: 100%|██████████| 469/469 [01:50<00:00, 4.23it/s, Epoch Loss: 0.0173]\n",
"Train :: Epoch: 20/30: 100%|██████████| 469/469 [01:49<00:00, 4.27it/s, Epoch Loss: 0.0172]\n",
"Sampling :: 100%|██████████| 999/999 [00:26<00:00, 37.12it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPquk6VrP2exg8ny4vs0L7d0KMeWQk8knk1x//AAu34h/9DD/5JW//AMbo+Nv/ACV7Xf8At3/9J46yvCngW+8TWtxqLzw2GkWhH2i9uDtRfUD1PtQBq/8AC7fiH/0MP/klb/8Axuj/AIXb8Q/+hh/8krf/AON10Nx8OPh1ptkb298dtJBId0Atwpdl+mDk/lWJqOq/DnQ7Ke30DSbvV710KLd6k+ETPcIuMn/CgDqvEXxS8ZWHw58F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM96xPDXxO+JvijxDZ6NZ+IlWe6faGeyt8KMZJP7v0BrK8W/8kh+HX/cS/wDR61P8D9IudS+JljcQkrFYq88re2CoH4lv0oA1/GXxF+Jfg7xPdaLceJxO0O0rMthAocEZzjZx3H4Vg/8AC7fiH/0MP/klb/8AxutT4npf+Pfitd2mh6ZLPLaRLbERjO7aSSx9PvY/CuJ8SeDdd8JNbjWbB7cTqTGx5Bx1GfWgD1Lw78UvGV/8OPGmrXOsb77TfsP2SX7LCPL8yYq/ATByBjkHHauP/wCF2/EP/oYf/JK3/wDjdHhL/kkPxF/7hn/o9q5/wj4cTxHqkqXNybWwtIWubycLuZI167V7sSQB7mgDoP8AhdvxD/6GH/ySt/8A43Ug+MvxKZSw1uQgDJIsIOP/ACHVn/hOND0a2e08KeCLOQKcG91OE3Mrf7WDwp9ulcxqnjzxHqaywy3gto5F2SRWkKQKw9CEAz1oA9K+FvxS8ZeI/iPpWk6trH2ixn87zIvssKbtsLsOVQEcgHrXH/8AC7fiH/0MP/klb/8Axuj4Jf8AJXtC/wC3j/0RJXM+FE0VvE1n/wAJDK8elq+6YopJYDnbx6nigDpv+F2/EP8A6GH/AMkrf/43R/wu34h/9DD/AOSVv/8AG60rjS/hNq13NcQ+I9WsXlkZ9klqojQdcAAcD0rN8bfDmHQNBsfEWh6kdU0W6wpmK7WRvf2zx9aAD/hdvxD/AOhh/wDJK3/+N12HxS+KXjLw58R9V0nStY+z2MHk+XF9lhfbuhRjyyEnlieteH16B8bf+Sva7/27/wDoiOgA/wCF2/EP/oYf/JK3/wDjddP4Q8bfFLxeL+WDxPHa2tjAZZria
yg2A9lzs6mvMfDPhfVPFurJp2lQGSQ8u54WNf7zHsK7jx5rlp4Y8Ow/D3w9dGSGBi+qXacC4mPO36Dj8gO1AGd/wu34h/8AQw/+SVv/APG67DxF8UvGVh8OPBerW2s7L7Uvt32uX7LCfM8uYKnBTAwDjgDPevD69A8W/wDJIfh1/wBxL/0etAB/wu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N15/XX/AA68D3fjnxJHZxqVsoSHu5uypnp9T0FAHRw/E34tT6LNrEWoztpsLbZLkWFvsU+mfL9xW94d+KXjK++HHjTVrnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtWR8XvF9kRD4J8PRrDpGltiQxtkSSAdM98ZP41jeEv8AkkPxF/7hv/o9qAD/AIXb8Q/+hh/8krf/AON0f8Lt+If/AEMP/klb/wDxuvP69V8O/DPw/ceGbC68T6zcaLqWoyN9lSTaI3jGMHBGec8c0AZifGn4jSNtj15nbGcLYwE/+i6674W/FLxn4j+I+laTq2s/aLGfzvNi+ywpu2wuw5VARyAetXfCfhDwX4A1X+19Y8ZWV3IYZEihj+VWVlKtnkk8EjtXD/Bso3xo0gx/6svclfp5EuKAIv8AhdvxD/6GH/ySt/8A43R/wu34h/8AQw/+SVv/APG6zPAVt4TvdTurXxXNJbwSQHyLhXKiKT1OOtdmfhL4TvsNpnxBsmRyRH5iBjkcnOCO1AHP/wDC7fiH/wBDD/5JW/8A8bo/4Xb8Q/8AoYf/ACSt/wD43WunwYtryVrXSvGuk31/s3x2yqVL/jk4ry25t5bS6ltpl2yxOUdfQg4NAH31RRRQB8gfG3/kr2u/9u//AKIjrQ+FvxA0/RbS68L+I4Em0HUGJZ258piADn/Z4H0PNZ/xt/5K9rv/AG7/APoiOuf8LaxomkXTvrXh6LV4jgqGneMofwOCPqDQBv8AxI+HY8IyW+p6Xci90G++a2nU52552k9/Y1wFegeM/ia3ibQLfw/p+kW2l6RbvvSGPk57c/ifzrgKAO/8W/8AJIfh1/3Ev/R611P7O6TjWNemgKb0sgFDDOWJOPw4rlvFv/JIfh1/3Ev/AEetcr4e8S6v4W1JdQ0e8e2nAw2ACrj0IPBFAH1F4P0jWtE8F6lqsOlRnxNqczzSQyHygW3FVyDnAA5x3zXIfG27ux8LdFttfe2GvSXSvJHCeMBX3Ee3KVzEf7RvipEVW0zSHIGCxjkyff79eb+KfFGo+L9cl1XUnBlcBVRc7Y1H8K5JOOp/GgDpvCX/ACSH4i/9wz/0oauR0LX9U8N6kuoaTdtbXKgruABDA9Qyngj2IrrvCX/JIfiL/wBw3/0e1ef0AeqeM/GvjHwzq40+DWLVLeeGK6R7OyijEgYZBOF55zVa7S28efDq+1ySKCDxBojIbqSKNUF3C5xkqoA3A45xVXStY8OeK9EstA8TMNMvLKPyrLV41LAJ12Sr3Hoe1TX48PeCfDmqWGl+IY9c1DV7cQM1tHtigQSKxJJJyTtxj3oAg+CX/JXtC/7eP/RElef16B8Ev+SvaF/28f8AoiSvP6ALmlWcOoanBa3F7FZRSNhriX7qD1Ne3J4b1uP4Oa94cvkWW0sk/tGw1GF90M8YO8qD69ePevIvCmiafrusC31TWrbSbNRuknnPJGeijua7PxX4807TfCa+B/B8kz6YhYXN9N965JPO0dlP8qAPL+9egfG3/kr2u/8Abv8A+k8def16B8bf+Sva7/27/wDoiOgDsvAGmXN18JLm38JSQtrl7c7NQfeFlhhzjC9yMYP4msvxBrHhvwT4gXwvZeELTUbW1Iju7i+jJuJ2P3ihB49uK8u0rVr/AETUIr/TbqS2uojlZIzyP8fxr0EfG/XXMM15o2hXl9CMLeT2mZc+uQQB+AoA5bx14fXw94lkit42SxukS7sw2ciKQblBzzkZx+Fbfi3/AJJD8Ov+4l/6PWuS13X9S
8Sam+oarctPcNwCeir2UDsBXW+Lf+SQ/Dr/ALiX/pQtAHAxxtLIscalnchVA6kmvoC6uF+D3wgSximjTxJq53nAUs
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAABt/ElEQVR4Ae3dZaBty1E1bNx54cPd3f0DgpyT4O4WnBAkQQIJECBwBQuWBHf34O7knouE4Brc3d3lffYea49Tp+dc6+4L+fee+aN3dfWoqvau7tlz7cf/7//+78e7+dysgZs1cHdq4AnuDvgm9mYN3KyB8xqw2swF5/Ef//GxEyKe4AmeoJwyU3Oi4SQMEz4PJqJhrWAGeclw4qch4o3G0BM+4RPWSpQXcElbJ2ApSACnrcwMT4W7mQH27CbFSpNCAFcnTh9MdJNmdIpMQOjFijoMv9oqXk6VxHqjiIDDF1bbUmNTZNJbEzP1LulaiZ5qQ5SuEhy5TYbDDAddYoLLP1hpxUF7qgKdYld1CPwza+ftFzwYIrRQUuhYisIWqfoRExn+Nqyq4LeAyV+s7IL/Z8yWjvhpK5ABQ5Y4ZjTVuJu6WKGqvXBRO6OT3qpNhU9MrIRTfompAXPJ7VZqK4jjWcoStfhT/yXpE1KxIpO7+dzVH21bnbscTM+hLNsiJZkZRMPFapKammiyK1w04MRKpaptcp7iKZ7ivd/7vb/ne77noz/6o5eSF1+L4UwYVduyTMHHFX2sLNE/S7RrcRewZR4rS5BbfGy1Qt7kTd7k1ltvvde97jXzILWA8BcrVVtiiofOLBn6GKz8EIuVrc675FThCeTWymmppBbT6MJhsRxlP1hZjFUYelbx5C9JYNU7bYSWBLBYkTSfiN/nPvf5jd/4DchHP/rRz/zMzwww1U586QWwa2XBVPYYcRp/l2WZareqypl1O0VK75YlqZRET0LMEgG8/Mu//Dd/8zf/3d/93Rd8wRfgbG2VEyuLtii5ZLiYrlR13q0aq/jdImJrluWY+Da3zecxEfxKIWLliYpO8UQR//Vf/4VIOItNJlpCRDaKIhW9CZsaYoYTgE8D8bd5m7d5/ud/ftGneZqneZmXeRnLzrQyxUvfJSDKi78McVpn6uQyenZNV3n0qIeoKv8ymgveEhF/+7d/e+vMkz3Zk1V/+LPa28qSoidhMWn3CDap4EUhW/e85z0139/+7d9+//d//x/90R9NnQHf3bA5uUvB2Aos9CzdMfHoT1gNW6OTU9jBU5d2TDt+UreYhXMMRgN7102eW1pkn+d5nudHf/RHA/uDP/iDBz7wgZe3GySFET9Xf/eCJ37iJ37ap33aZ3u2Z7PKPfuzP/srvdIr/f/nz1M91VNR66m6dKalLE29S4IqGgKbard0y1I8KTAPjrC2znnXo8bML/3SLxG/4447Xvu1X7u2pkiZsbJNqnLEor+yxfyf//N/3vzN3/xXf/VXafu1X/u1V3zFV1wwLUtF/seEKfVN3/RNn+/5nq87vajatksKNWtvMRpAy74Qi/7KpiyH1eaJnuiJ/v3f/z2SEgoKgSMpfPnojIuTJvzP//zPigRG4VM+5VPS+U//9E+7WY8sVdH893//9//8z/8cJarmJV7iJdAwtZukRLf8GG0eZibLnIRK0diGCuIZn/EZn+mZnklLGDN/9md/Zgi9xVu8BfDf/M3ffNAHfZAuOAVb9sks/SRP8iTP8AzP8Cd/8iddTJaMiYazlCvM6Jk0GFWzOGfyFw00ayOyL/7iL/5pn/Zp8qAgX/mVX/l93/d9kSUC3HyGKKcK8cPEiWBMLIKBRcrMoh9/2Id92Au90Av927/9G+u/+Zu/GXwAtRLmMYVbE1sOWVb4n3a/n/EZn/FXf/VXxbRdWlexXj5kTC+A8oMXRUDOLt08Iw4WgYLL8KJU2vXkA+rsT/lL6ozqc/Q853M+573vfe/bbrtNhRI0hGJlKLtOVu23fMu3gP3Hf/yH8Od//ueDmMqvy9xIyXMYJ6xMCfmxnnz6p3/67/7u7xJRs4x6EOw+4hGP+JAP+RBOjnE1pUofs6LgL/dyL/elX/qlT//0T18w4i6L0PxPqVjZTQqsal
uBPKXP+qzPMtrJ6sdT2y5NeaxUFVi1lZiyZUZE+Gqv9mrf9m3fRo/a+/Ef//EXeIEXmPjQsbLl310Oc/e4xz3s2XSVF3zBF1zEW5bkTaqGNpF5cMpcpLbRLXJyDmXZFilNlVFEKZljjbckPfVTP/X7vM/7POxhD7vzzjs13u/93u+l8ajaWkl2Z4a++qu/OjDhL//yL2/Lk8xs+VVyzMrMP7/lO7/zOy2DASN+5Vd+5aEPfejrv/7rX7lyxSqXWjb+p6GaaFebqWh8zsmjHvWor/qqr3qO53iOJfXuRpnbLUuzEYWieURVsrngH//xHwl+5md+5ok8EKn4MSvFNOc1VA7CtPgTP/ETBoxq1Jvf/d3f/Umf9EkLqJJdK4UthPpfOI0+3dM9nfX/X/7lX9h61Vd91fJDbK1cvXrVkvtxH/dxzkgWcPO28BNN6haDEyvXjwQIqHprkwR0Fim9QaWEg2mn/i7v8i7Xrl1z5KWfvfALv7ABbZ42v1o9eWVm6Cd/8idXd7/wC7/w2Z/92V/+5V9OigZhH7Y9YVazVM5SMP/6r//6Uz/1U+hYryBi4svfZTZ1Kvn4j//4d37nd+bD/NZv/dZP//RPc8C++7u/2/71H/7hH3S4uS7rB9Uw7c6yKEUwMqBFX+d1XkfH5Tl87dd+LVXmfp6neuAvyQMXFObHfuzHbDnwIzjzNs2VZqKlmzQAWY/KD/h5n/d5P/IjP9IhvkI98pGPtD+skhBp3EVJMVS1aCy2aAUkG1OcT/EO7/AOr/AKrwCjdNqa+6TtFhGaw0kGmrpLvN7rvd7HfMzH6Ehax8r/O7/zOzrbH//xH+sb/E9TnnlZHs5WkOOjK5k08VkJ5fBP//RPeeB/8Rd/YS7+/d//fS4GZ5L4r//6r3dfsGQmhW3Nt3JaLYdhE0uz30xFRgUzBsm7vuu78vvlXu3ofPYGYKpDOYPXZnLzIz/yIz/8wz+sx+iLNVmFctMMlflar/Vaz/3czy2qv5qzP+mTPilJyVhhJ4gWacGkN1gJud0qUQ4/9VM/1fHDn//5n1sS1eaCPx2dVhStNWbWuN/97mfKcKLgqZJXf/VXNyZNLrLheeu3fuvP+7zPcy7MNFU4RYZYypuKCnNWWrLRMfOSL/mSn/AJn8AWJbfffrv6X9Ruu2w0FBblbaxpqxhE+cTf933f1+kZpm79gz/4g5zebDaqJIIVaV1NhQttp6QsOpuO/iqv8ipUmUAJPsuzPIv+9qzP+qzpcse6O22pK4SORPYnf/InDew3eIM3oMRsZactRKsQh37GD1Ws6M8e070ey9X8pm/6Jh1DzlNLaaZqpvwwbCBas0ELW/7Xfd3XfaM3eiNuK4dS/zPuW1owZvRFTWhaNZX+9V//tU2hPLUMsV2RXeLN3uzN8q6GoPc22dsku/ri/3f+6I7KZsJW2iovISe7mjHl+R3f8R2tM1/zNV+jyxoz9BwDn+bXiuqa5TIIv/Ebv9FhhnaSvRzKUWX5fdEXfVF1YvHRfspx//vfX8sZPFUVi2mVhZkkzNlmmDiepL7Ga7zGAx7wAPM0tabquKCSkr3AEk4lOAHUBEKFT6kkNZzi7/Ve7/W2b/u2urJU7W5K4usGOeeCaaJ6ThDWTKkayBwaOpPpFNF2pmbzzmTWUKtFqiXLwstf1VjWqwy5KbVLc+14DZ/4iZ/IpZplmZqvrzapMmkdLTi6LL/QvOI89hd/8Re/7Mu+zCB+6Zd+aefF+rclz5JiqOQ1JY9zKcy01CxG/2yD53qu5zK1ZMmK2oJf9mVfVodgjodjyjEZfN3Xfd0Xf/EX0+wJLEQrrrIldGLrJH/sIz7iI6z75Z8goq0mtkhJs055feZ7PpITWGBldDoHo1Av9mIvplAWHHcglEU/0OE4DHaAU+20Vbq1VE5EwjfR2CJ7Tcxhlplv//Zvt0mzuAUzRZLVciI+oxXBTNmDmWHxzksM/nRoK8y3fuu3PuYxj4kGHZStIqstqSdC0w
ovxuzJp7LL5WK91Eu9VPYk6pDaP/zDPzQ7m5IcVGrE3/7t357aanEy//Iv/9LhARFZ0sfMWdYr48cKo5YMS6WDYVr
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 21/30: 100%|██████████| 469/469 [01:49<00:00, 4.28it/s, Epoch Loss: 0.0171]\n",
"Train :: Epoch: 22/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0168]\n",
"Train :: Epoch: 23/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0170]\n",
"Train :: Epoch: 24/30: 100%|██████████| 469/469 [01:49<00:00, 4.29it/s, Epoch Loss: 0.0169]\n",
"Train :: Epoch: 25/30: 100%|██████████| 469/469 [01:50<00:00, 4.26it/s, Epoch Loss: 0.0168]\n",
"Sampling :: 100%|██████████| 999/999 [00:26<00:00, 37.03it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPq2k6TrP2exg8ny4vssL7d0KMeWQk8knk1x//AAu34h/9DD/5JW//AMbo+Nv/ACV7Xf8At3/9J4647+xNTGijWPsU39nGQxC52/JuHbNAHqF/44+MWneH7TXZ9Sk/s26QOk6WVuQoJIAb93wTjP0rA/4Xb8Q/+hh/8krf/wCN1sfCrx3O06+C9ahfU9J1RhbpG7ZMOeOO5HTjtXCeM9A/4RjxhqmjBtyW02EP+wQGX9CKAPU/EXxS8ZWPw48F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM964/wD4Xb8Q/wDoYf8AySt//jdHi3/kkPw6/wC4n/6ULXJ+HdAvfE2u2uk2Cbp53xnHCDux9hQB6fp3jP4w6p4ZuPENvq6/2bbsVeV7W3XoMkgeXyO31Nc5/wALt+If/Qw/+SVv/wDG66b4m+JLLw14Vs/hzoVy0v2QYv5xxubrt465JyfoK8gtbK6vZPLtbeWZwMkRqTgUAe0+Hfil4yv/AIceNNWudY332m/Yfskv2WEeX5kxV+AmDkDHIOO1cf8A8Lt+If8A0MH/AJJW/wD8bpfCYK/CL4jAjBH9mgj/ALeGrj9C0W88Ra3a6VYR77m5cKo7D1J9gOaAOzh+M3xJuJNkGtvK+M7UsICcfhHTZPjT8RonKSa8yMOqtYwA/wDouuyvJNK+BNj9nsGh1Lxbdx/vJZAdkMZ/2QePb1xVDR9W0z4wY0nxHam28QZ/0bU7KAYfj7sgHagC38Lvil4y8R/EbSdK1XWPtFjP53mRfZYU3bYXYcqgI5APWuP/AOF2/EP/AKGD/wAkrf8A+N1tfDzwzfeEfj5pGkag0LTxidt0TbgQbeTH4+1eTdqAO/8A+F2/EP8A6GH/AMkrf/43XU6v46+I2geHbLUdT8Yww3l6glgsFsIGk2HozHZ8tcP8PNItbvWJ9X1SISaTo8Ru7lW6OR9xPfLYFYXiHXbzxLrt1q1+wae4fcQBgKOyj2A4oA6z/hdvxD/6GD/ySt//AI3XYfFL4peMvDnxH1XSdK1j7PYweT5cX2WF9u6FGPLISeST1rxCu/8Ajb/yV7Xf+3f/ANER0AH/AAu34h/9DD/5JW//AMbo/wCF2/EP/oYP/JK3/wDjdcHFDLcSrFDG0
kjHAVBkmvQdA+EWp6lYNqGs6jaaFZgZVrs5dvfZkED3oAi/4Xb8Q/8AoYf/ACSt/wD43XYeIvil4ysPhz4L1W21nZfal9u+1y/ZYT5nlzBU4KYGAccAZ7155rXhLSdMtZZrXxfpmoFAcRwhgzEHgAH19a0/Fv8AySH4df8AcT/9KFoAP+F2/EP/AKGH/wAkrf8A+N0f8Lt+If8A0MP/AJJW/wD8bpPhv8O08YG71LU7wWWh2IzcTBgGJxnAz0+tegaDe/DHW/E9r4R0vwn9tglTYdQdipJUEknHOPcEdRQBwH/C7fiH/wBDD/5JW/8A8brsPDvxS8ZX/wAOPGmrXOs777TfsP2SX7LCPL8yYq/ATByBjkHHavOPiL4ctvCvjjUNKs5A1vGwaMZyUVhnafpWv4S/5JD8Rf8AuG/+lDUAH/C7fiH/ANDD/wCSVv8A/G6P+F2/EP8A6GH/AMkrf/43XAKjO4RFLMTgADJJr1PxV4DsvCHwl0+61K2VfEF5dg7g5yibSduOnSgDL/4Xb8RD/wAzB/5JW/8A8brsvhj8TfG2vfEjSdI1nVmls7gTGSFrSKPdiF2HKoCOVB4Paub+Evhy08298Y69CDoujxtIBIuVmlxwuD1x/PFT/DbXbjxH8fdN1S4eRjM9wUV3LbF8iTCjPYelAGV/wu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N15/RQB69b+OPjTdaQdVt5b2SwEZlNwunQFdoGSf9X0xWD/AMLt+If/AEMP/klb/wDxuvTfhfpGoeEfBVz4q8Sapc/2f9jMlrYtOzRrGVznYTjcRwBXzrMyvM7IMKWJA9BQB9+UUUUAfIHxt/5K9rv/AG7/APpPHW1p9/NqP7OOoWMMRZ9P1FQ4Tr5Z+fcfbLEfhWL8bf8Akr2u/wDbv/6Tx1D8OfiPJ4ElvIZtPj1DT7wDzoGbByO4zkfmKAOl+CXgG71LXYfEt5BPFY2LeZbsMDzpB257CvP/ABtqt1rXjbWL68BEz3TrtIAKqp2qvHHCgD8K7rxX8dNV1jT203QrKPR7NwVfYQ0hUjkA4AHfoM15MTk5PWgD0Dxb/wAkh+HX/cT/APSha7v4AaGtxoWv6jZ3CQ6s+LaKYru8hSM5x3yf/Qa4Txb/AMkh+HX/AHE//Sha5LR/EOseH5ml0nUrmzZ8b/JkKhsdMjoep60AeuTfBO207UBfeMfF1tBFcTHOBteVic9ST157V1LeKvCPg2OZdN1DR7axgypt7GPzLi+GzADMfujdnNfOuqazqet3P2nVL+4vJsYDTSFsD0Geg+lUaAPRvDs63Pwu+Jk6psWWTTnCf3Qbljip/gbPbweOpVZ1S/kspUsC5wvnY4z68Zql4S/5JD8Rf+4b/wCj2riNPvZ9M1G2vrZgs9tKssZIzhlOR+ooA9s8Q+GPBuh68t78QPEtxq+sTyr51vbKEUA92A5Cj2IpfiD4wuvCFtYL4Dh06x0O8h+S+s4VaSRh1VmIPI468nmuW1TxP4O8eSRXviSK70nWNu2a6sU3xzY6Eoeh981keMtf0c6NYeGPDTSyaRaObh7icESTTMMEkdAAKANT4Rajear8bNIvb+4kuLmU3DPLIcknyJK82rv/AIJf8le0L/t4/wDRElcBQB6t4E1LRdK+EniWTXLea4t7m9igEML7Hc7c8N2x1qHTNP8ACXji21Sz0nw5No8llZvcQ3hunkyyjOJd3ygH2Arj9A8WT6HZT2Mmn2epWMziVrW8VigcdGG1gc/jV/VviJqWoaM+kWNjp2j6fKu2aHToNnnD/bYksfzoA5A9a9A+Nal/jDrajkk24H/fiOvPq9B+NhI+L+uEHBH2fBH/AFwjoAoeGPF154Aa/itdNgOsO6qtxMu4wgdQB6n1rZ8V6Ol58PLbxZqVkLDW7q9fIMzsbuNud+1iduDkYHFY2n+PVsVFyfDulXGrKFC386MzcAAHZnYW4645PJrn9a17VPEV8
17q17LdTnoXPCj0UDgD2FAGbXoHi3/kkPw6/wC4n/6PWvP69A8W/wDJIfh1/wBxL/0etAF/wB4v8N2/gzVfCPigXM
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAABnTklEQVR4Ae3dZaAly1U+fNzd3d3dZYbgrsGCBg1uCUGTe9EQCBJcQtDgEAhO4A7uGtzd3fX/O+fZ95k11b33OXOT99M7/aHPqlXPkvJVtbv7PPb//d//Pdat61YN3KqBm6mBx7kZ8C3srRq4VQPnNWC1ucyC89iP/djguU9itxYLa+600twShV2GIOV6nMd5HPfg0YjHe7zHq5UCSgTvHgI/V5LVM3PDBFuYsYIJkHsINB/Qj/u4j0sqzIOZO5PJffzHf3zgXEHWCkCuWKEwgDCJLEn8yUG7SPEhYEkYNM4ClmyNoe/CFeUVXJLlTyv1qmBE6YpsOcna5Yc5rRTcGgvGfSFqMQQ8opgky0EcrGyNRT6SoXsvs+qatUsUdswKKTqrdqtkFmObu8jWSvjpOjVRJoLamUQH1juizqfDSYYzrYDVSohg8F3oaEbH6Dn7htsUn+BaiaysXNWZ5NRV03diz4rpkiwMHU6IaaWYEtHTJMHSJabyMhciVnaRZZZYZGdyi5mcpSzJci8xVaHDX5gzmXZfkIey+LNkTMnHFD2tcDfJXeWncyuyC5tWIFMvmAgXwoWvOhD/+7//G23JlQwRTO44wUjqNBFZsgBkTSmcmEuWu9wAXvqlX/r1X//1n/3Zn/33fu/3HvGIR/zqr/6qXP7QHIBkriTjEo5kdFZzYDEtN316FqoaQgQ/lRD5n//5H3zM5J64B1M3prZKLb7hh1NPinz0idpSirRdyh4/p/4C4kaSAYTe9oelmMA1FyVnQcVNXeSf+7mf+33f932/6Iu+6Ld+67ciW6WXURXDu8jqeYIneILXfd3XfamXeqlr164JOX7zN3/zj//4j6etE0rAUh3FlEhW+0r4uRNB5OJGPJGsn6Hxywl9kDnv1s0KuDo1DKNv/uZv/t7v/d4v/uIv/su//Ms/8iM/8vd///fBx5/KLkR9mC5NNwIoLOJNlsCPhsp2jE2LwUzOonBm7SqfzElPwbtAL47RHE5KUUNhTnAGVQGz1KGX+i9yOrkypVfWgDOfVIknfuInfp/3eZ//+I//eOhDH3qve93rhV7ohQb8KHnaymJC8vmf//kf+chH/uM//uMP/MAPfMZnfMaLvuiLTtV1ZjLR08oWE457r4g/0RM90aInSbASaGMgyViZSooMJsnQlXryJ3/yH/zBHyRrwLzBG7yB7U2NTqmolbWUpXoqhcAMvw7M3EkDQFa5rIgsVqbI5emoyn1KlRMrkvFhYkKnFJNf2cncpYtsWcqp8q3+qep5n/d53/AN3/Cd3/md3/M939N68E7v9E6muQmY9MFKjc28E/RTP/VTP+QhD/nv//7vCFpzMnIWX6OhzNNWwGajPv3TP/1Hf/RH/+d//mekDBtRTV2KzmouH7G1UjAi9MRb017u5V7OCnD3u99dxb3iK77i0z3d001A6ci687NWzjSeD63cgZsspxo0yZ/8yZ/83M/9nCCtTLAUfOFIxsrUM2mAdIXcIx56wqK8nCQnuGUJM8jiw3Tfcpq1Syz4xcoisoCX5AJO8imf8ilf8iVf8jVe4zWe7/mer4BaoSFXsyRLb4n73ve+ouWIu+vbr/3ar72FhXOwUmO4p7VH7Bme4Rk+/dM//ad+6qcsOGT/+Z//+RM/8ROf8Amf8JiZG4zdCDpm7sVe7MV+/ud//t/+7d8U4M/+7M/e5V3e5Sme4iluFN1PpSzpHJ0w0pkisNSmAf9jP/ZjkXK33/i4j/s4zKw/cW/2vIgHX4WI3YJMuwKzv/3bv/393/99I3O6HoUVTzKCsRJOTQRZ/PRhqp10/V+kUj+zLFNql9bK1swne7Ine9qnfdqneqqnEjxHecHTxKQXKzNr0tFDZ9f/RX8Az/qsz/ohH/Ihv/M7vyPKvfe97x2m4sQKgs5cdew08UEf9EHf8z3fQ+Ef/uEf/vVf/z
U9P/RDP6SYu1KHsixFAmVyCixJWU/yJE9ioBuggkLiInUL3BRU2qkBvbUy8QvYfCw8E3f+13/914Me9KBnfuZnXgCR3ToWK6lruXUD7VqUPNMzPdN97nMfIqwYn+5oc8Edd9xxj3vcI72KVJSUoGS3LFEeK4sthq5du0bz537u5+6WJS3NUGzFdKwsqmIlMPRiLvxdEczJL70tS7POBEalCQHe4R3ewcpvbn74wx8uRH/Lt3xLsbQZbcImXQ9PWEmJercFEJBb/Nv7mxVCcPvZn/3ZFP7CL/zCm7zJm7SZ5F7eytQZhzPvv8iLvMhXf/VX0/Mv//IvM8CZ+IOVrbEJOkZroVd5lVf58R//8Yj/xm/8xtu93dsdA+MHxsV4uYtMq1t/P+zDPgxeP/7d3/3dV3u1V0vVTJFjSmZZtlJTg+n/y77sy0wtrFjWTDaCqNtvv11BKPmlX/qlt3mbt5lz3pSdVs4L9NhsueJV+3QIs/J7vdd7GZa2NM/xHM8x9Sx4WeG4k40VtKtSpWsunN4LqEjcmPdkBTnLUpEJQOvKhofR4gQI/l//9V9V2r//+7+rN/PmV37lVz7ncz7nIrskT1hZkIai6fjP//zPHZ8sWZJKLai2blP4Xd/1XVevXsUpbFpRuhQwuZMufku81Vu91aMe9Sh6/vIv/1L33krhHKxMY1WUWm7yGPFRH/VR1jUahGo2PMdg+LtWdvFmek5rGHsb/XiGZ9tiVEOyakXSpRQhwEqgX+mVXslkmSDQKfBrvuZrWgTsc0Qgb/u2b5uw7S/+4i9MPC/7si9bEyVihcL03fZInGJKCL5/+Id/WEu/+qu/eo8B4ox7PATeyrYsVRWpJhepqIozE7NLx7SsrZUF/2zP9mxf8AVf8Gu/9mu6ssEPr2kSZUTWRGPG2fo/PamVLWyaM2PqUZZl0ZdjpzkkArPi2UtzQCTy4Ac/eB5H0VwrAW9tbTnTul7xnd/5nSmgcN3eaesA/MFK/pzWOLVP+hmf8RlFawYMJb/927/9eq/3ejN30gdj56zY2rVouP/0T/90wM7Qnuu5nmsqCV3BEsVEsPzZcsH42eQzP/Mzf/EXf/Hv/u7vzJdssThrJwffP/ETP0GVzgE5d/BRuFhZBsNi1DgUcOp5jZW5Fw9z5xiR0Pl1/yx70wnAMANOWU7cLdGC2+/+7u/WD770S79ULMoBW0Q7XavoK7/yK1dbynJMlfDS/PI3f/M3gbmLYC0IpgBnNr/yK7+ii+vB1urneZ7nOaYE/7SVCvq94du+7duA/+AP/uCN3/iN1Yas1qdm0sH+4R/+AeBHf/RHX/VVX7WCIS5jJXVYwbOKPq9VC6YqspAqjo5h6+54rbnwgSEOVi40ZidTM7uESNTGnbHP//zPbyEX5IVW4F/+5V/eygvJe4SjrUXJhcnFiqKm6glyzCnZF3/xFxswgZlB9QAzXNXGebHZ3e52N5VozOgWhtYDH/hAs0Nh00pqs3UKg25SrPx1X/d1TgKuXr0aZrMQpdmN6ZpATCuSAUypSVfQT2qf9EmfZH0zVqPBXf8W7iqOgOqrvuqrbFHgYz2YelI9ONbJL//yL48ejSvIvP/972/ABCz4NCwjTnk2txVfiMDK3BZWloFn0/JXf/VXwJrmBV/wBRevNN+XfMmXyNUo7/Zu7zaPbYJcrNTchYTfA77iK77CckqDudIvhG//9m//pE/6pIvgtHI2oJdLNnlMxVMYfevXf/3Xf+ZnfsbYCH/B2z8JqMzTL/MyL7NkXT5pTyY8E1ASsdvT9jZO8YQb5oBFlTneWgcgmmK9Pi8wcUU4APS/6Zu+acaJVVgbfOu3fivNkXWPFV3EQmfCM0mbcgTTNql2Qbb1MK6aQKsQdx6q7hqqt2Z91/d///dnCQVIBVawqqYbi7ZZ5zWHcFVcJ7ZxErT4zcE8jW/qsUn4oz/6IwVRWEkTqrpSEDu3ehINU1U4Oo2fmy1QknyzCKgupehPtJhdpVWyJo
jgck9JF+Y2qfVFeuZfY4MzprY//dM/nV5pNeu20EA9G65mVc1ETzATiTmNol1tkZjGmSJ+q3G6kCyVb5r72q/92iT
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Train :: Epoch: 26/30: 100%|██████████| 469/469 [01:50<00:00, 4.23it/s, Epoch Loss: 0.0168]\n",
"Train :: Epoch: 27/30: 100%|██████████| 469/469 [01:51<00:00, 4.22it/s, Epoch Loss: 0.0168]\n",
"Train :: Epoch: 28/30: 100%|██████████| 469/469 [01:50<00:00, 4.26it/s, Epoch Loss: 0.0167]\n",
"Train :: Epoch: 29/30: 100%|██████████| 469/469 [01:50<00:00, 4.25it/s, Epoch Loss: 0.0167]\n",
"Train :: Epoch: 30/30: 100%|██████████| 469/469 [01:50<00:00, 4.26it/s, Epoch Loss: 0.0167]\n",
"Sampling :: 100%|██████████| 999/999 [00:26<00:00, 37.39it/s]\n"
]
},
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPq2laTrP2exg8ny4vssL7d0KMeWQk8knrXH/8Lt+If/Qw/wDklb//ABuj42/8le13/t3/APREdcfoekXOva1aaXaIXmuZAi4HTPU0Aet+FPGfxP8AE9je6nJ4sg07SLIYnvbiyt9obGdoGzk9PzFcvN8a/iAk0ixeJPMjDEK/2G3G4djjZxUvxO12ytUtfBOgsBpekjZcOqgfaLkcO59eePrmvN6APcPEXxS8ZWPw48F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM964//hdvxD/6GH/ySt//AI3R4t/5JD8Ov+4l/wClC15/QB6jonxP+K3iPU49P0rVnubl+Qq2VvwPUny+BV7xP46+L/hC5SHWNW8rzMhJEtbZkcjrghO2a6LT/Duv+Bvg9HfeG7KSXXtTYSXVxCqu8EBUkBePofxrwm+1C/vpWN9dTzPvLESuThj1OD34oA9l8O/FHxnffDjxpqtzrPmX2m/Yfskv2WEeX5kxV+AmDkDHIOO1cf8A8Lt+If8A0MP/AJJW/wD8bo8Jf8kh+Iv/AHDP/Shq8/oA9BX41/EV2Crr5LE4AFlbkk/9+63/ABF44+MXhaO3l1XUnihuEDRyiytypz2z5fB9q5X4fafpUFxP4o12YDT9JZZEgVsPcz9UQe2Rkn2ruh421Lxl8PPHepa1IDYjyUs7YoNsTF/lCnGc+9AB8Lfil4y8R/EfStK1bWftFjP53mRfZYU3bYXYcqgI5APBrj/+F2/EP/oYf/JK3/8AjdHwS/5K9oX/AG8f+k8lcBQB3/8Awu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N1L4D8FaXPpM/i7xbI0Ph60baIQCHum7BSCOM8cVc8SeN/CuteDbqwh8FwabdK6iwuIYwp2Aj7zgAk469c0AZ3/AAu34h/9DD/5JW//AMbrsPil8UvGXhz4jarpOk6z9nsYPJ8uL7LC+3dCjHlkJPJJ5NeIV3/xt/5K9rv/AG7/APpPHQAf8Lt+If8A0MP/AJJW/wD8brd0vxz8aNathcabLe3MJbaHj02DBOM/886z/gt4LtfFXiaa71JPMsNNUSvGy5WRjnAPtx0rqdb+L3iDxB4nTwt4OigsIJJfscMjIN55xkdlG
B2HFAHGzfGf4j288kM2vFJY2KOjWVvlSDgj/V11niL4o+MrH4c+C9WttY2Xupfbvtcv2WE+Z5cwVOCmBgHHAGe9cz8T/h3F4I0/SLiXUZrvVL9pDeF8bdwwcr36sep7VV8W/wDJIfh1/wBxL/0oWgA/4Xb8Q/8AoYf/ACSt/wD43XVeIfGXxS8OeGNF1y58TxtHqasRELG3DREdM/Jzkc141bwm4uYoV+9I4QY9zivVfjVcR2r+GvCtnG4TTbEOyhicvJjjHqNp/wC+qAMX/hdvxD/6GH/ySt//AI3XYeHvil4yvvhx401W51jffab9h+yS/ZYR5fmTFX4CYOQMcg47V4pNbT2xUTwyRFhkb1K5/Ou78Jf8kh+Iv/cN/wDR7UAH/C7fiH/0MP8A5JW//wAbrtfC2s/Grxfpn9o6brVutruKCS4t7dNxHp+75rwqtZPFGvR6dBYRaxfR2kGfLijnZVXP0NAHo3ijx/8AFzwfqIsdY1fypGGY3Wztykg9VPl81qfC34peM/EfxH0rSdW1j7RYz+d5kX2WFN22F2HKoCOQD1rI8TNf638CdF1fWpZJr631BobeaU5d4Sp6nvyOvXisb4Jf8le0L/t4/wDRElAB/wALt+If/Qwf+SVv/wDG6P8AhdvxD/6GH/ySt/8A43Xn9dx4B8OaBqNrqmt+JbmVdN0xAzQQnDzMei56/lQBaX41fEV3Cpr5ZjwALG3JP/kOrOofFn4o6VMsN/q0ttKyB1WWwgBKnof9XXWeIdW8NeA/Cmg654W8O2Vtqmpx+bEblmmeCMjO4Fj1zxmvFtV1fUNbv5L7U7ya6uJGJLytnqc8eg9hxQB940UUUAfIHxt/5K9rv/bv/wCk8dSeBFuNH8H+IvFOmxGbVINlnAVGTbrJnfJjvxgdOKj+Nv8AyV7Xf+3f/wBJ465nw14p1fwjqYv9Iumhk6OvVJB6MO4oA9S8F+AdH0W30/XfF8U2o6hqLLJY6XGhdmzzucd85zzxXCfFKw0vTPiPrFppCIlokinZGfkVioLBfQAk8dulbet/HPxbq8LRwCy00yLskks4iHdfTcxJH4YrzV3eR2kkYs7ElmY5JJ6k0Ad94t/5JD8Ov+4n/wClC1xml6Tf61erZ6baS3NwwyI41ycV2fi3/kkPw6/7iX/o9ai+GXxDT4fahe3EmlrfLdRqmRJsdME9Dg8HPPHYUAet/DbwR4o8Jqdb8Ua/dQWNnAzJp/2lnjC4z8wJwMc8Ada+e/EGoLqviLUb+NQqXFw8igDHBPH6V0/jj4q+IPGxEE7rZWAyPstszBX54L5PzHp7cdK4agD0Dwl/ySH4i/8AcN/9HtXL+GfDd/4r12DSdPUebJks7cLGo5LE+gFdR4S/5JD8Rf8AuG/+j2qh8P8Axna+D7vUTfaUNRtb+1a2ljWTy22kjOG9KAL+sfCjU7W1N5oV/aeILVJBFI9icsjk4wV/rU/jL7F4U8F2Pgy2kEmqPOLzVmRshHCkLF+G7P4VLd/GC8sEW18HaTZ+H7JWDERIJJJPZmI5Fcv4l8YS+KC0l1pOmwXUjB5bm3iKvI3qecfpQBufBL/kr2hf9vH/AKIkrgK7/wCCX/JXtC/7eP8A0RJXAUAeq2Pxsmg8OWuj3nhjR72O1iCRedFldwGNxXpn6YrX8S+KD4p+AxvbzTrG2mTUFht0tISiRgHnGScZHpXidetaa9pqf7O99ZyXcNvNY3xlUSEZkOd2AOpODigDyWvQPjb/AMle13/t3/8AREdef13/AMbP+Sva7/27/wDpPHQB6V4Te38B/AC41iT93d6irsrKBuLMSqDJ9hms34IeBxbxP431h2ghiBNruxhlGdznPbsOnSsXwt8VtGm8JxeE/GmkC60yJVSOWHO4AdCRnOR6gioviN8V7XWdHi8M+FrdrTQ40VWZl2O4H8OM8L/OgDmfiZ40k8a+L7i8jkf+z4f3VnEx4VB1P1Y5PryB2q74t/5JD8Ov+4l/6
ULXAV3/AIt/5JD8Ov8AuJf+j1oA5Lw75f8Awkul+a+yMXcRZj2G8V9E6Z4dOq/tCavqWpw74bKCOWzJ5U5AAPvj5v
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAACKCAIAAADpF1LuAAByfklEQVR4Ae3dZaBlyXXd8TjMzMzMDIrTIwUcZma0A1biMCkezYRBYXKYmZkkjSbMzMzMDHZ+7/1fr64593VL/j7nQ91du9Zeu6pO0akD94M+5mM+5uO8ebxZA2/WwMemBj7uxwb8JvbNGnizBu5rwGyzCeeDPuiDqpUJZyV9vI/38T7ux73rZlIdyQNcotMnnF4uSS82PMGcntFb+eLl43/8j3+acHSX74PklLHdJ965UNLIl7cTmZc0MzkzM6s4l3SCw8ihVLJD6qKUK0tWAU6Gk3ZyAvCE55l8gk/wCcDyQnDM0dho6NVGGRNWM/RhTiE5E6kR1mbm5WI1Ry8Q5qI8PA/J3WosDM2EMiM6tpPnVskXpWMkwz94uThb8sVGdEmES/Q26eLveV4yfL9s47/QTp9wenke56P6VevJ/yiSo7yUWrOgZFiTShZ1hCGUvTQpL0kBshVKzcsdy1OekdCEXNKyEaZUYX3j9LVMSj29zDAhk8I0OR3sEj2tyDMnOM7zMoaSFr0It6k0iiMc8iwL5XleTgwTuQWe7TI/zYQZTjgN5+UNoxot+x1FZ38Kc/yVvtJXevXVV3/Oz/k5b3/725VqJkjIwTi+LVKAk3PyC5LCRHsLW1MDk7pMFh3/KYB1nMpTlnpGca4RkC+pkGnGOaEBDGC5Cnme+yWBnWU5M5DVXI8/zMlGc0lNIwwmdWWhxOkgJCecDF/yS37JX/ErfsWP/JE/8hN+wk8otaSZnHhKqcOcXoIJwwSLJHmAW+FSugvgUS9hlsnTxaOuUwoT5kI0krzcLRIcA9GemkXDLErA8t2+23f7Dt/hO3yZL/Nl/vk//+fvfe97/8//+T9gYYQ4P/qjP5rm//2//xfnzAdLfwkHw0D+ZJ/sk33lr/yVv+yX/bL/5t/8m9//+3//v/yX/zLaYKpSdCaolkoZw4X/jC6TX/7Lf/kv/aW/9D/5J//kD//hP7wMQ8Y8nsjp57fapKcB+7//9//SBNNVUMUw4fROvvV1ArJFSHBwkYCfL4doqVmdbDRSx5ZtymDMl0qOkyZ5SZF8ok/0ib7CV/gK3/ybf3Nn4U/+yT/5R//oH/3f//t/n0gyZCGq5Xk8FyGMkL6Q+Wf6TJ/pLW95y+f7fJ/v837ez1t9Ur7nPe9xUv71v/7Xsh3/hUqUfsozJyXFT64sBKfGmYJcUsgAg1F2pFGoZ/GBpj0zEe5W852+03f6u3/377L9m3/zb37v7/295WPIeMZGD+Z4cPnGQj5Pmccv8AW+gOHtj/2xP/Zf/st/+Vf/6l/9ql/1q/TVz/JZPsuZH3LRvCSPNqHMlOR8LG8DmzP/8T/+x6+//vqHfuiHfrbP9tku5qJD5iXAlCeecnq+zqTJYWoZUxI2Y1/KsgxPAM7LqfmUn/JTfr2v9/W+/bf/9t/km3yT7/Jdvsv3/J7f04izPMAHzpD+LMt4ytuZK7LB8df+2l+rqWm7Vhaf/JN/8gAnYZqRi5Z6eglzhuHTfN2v+3X/3J/7c//xP/7HTAr/6l/9qz/qR/2oL/bFvthpRT4NRedlBTnxF3DRwp2FE6PrfrNv9s1MsJ/4E3/ifJX64GXO8nFapjnrZRlydhWGrfA7fsfveOZvJ+kkvHg58S+QNd+f9bN+1n/7b/+NuROmWf/9v//3/8bf+Btf5at8lUfzmRdJ8rmsnvxn3k6Gt771rYbPzP/9v//3P/Wn/tQv8SW+xAmIJM6VBaBDKuFRj5Kep88q5kIki+blebYhT/yn/tSf+ut8na/zg37QD/rbf/tv/6//9b9Mm2YDJL/oF/0inHguZWfryEvtZr7omQinEVVF7373u+H/5//8nx/xER+hf4ZZSHBkkm0e5yVamIRLlN4QabVflv7Tf/pPOs/6z3/+z//55/28n9eYcufmcBQPvxmOlmZyY/qiMRRSXvQG5a
/5Nb/mL/gFv0BL0xI+z+f5PEMS8vJsihDPNxZyIdwm0GR1YaT51t/6W2vT/+yf/bMf/+N//K/7db8u3sw3k8aAqtQXhPN1Yj7jZ/yMP+AH/IDv+l2/6yf9pJ/0v//3//4n/sSf+B2/43f8i3/xL774F//i5usTmbzZlkeZ1M2GWZ0OI+nMmIH5K37Fr0gp89qfxif6Lb7Ft7g4Ok1ikHOC0FEqgTs86ScrhRWIBac8/Nf/+l8tdbS/MKMawxyV4cgLFW3kwbj47J/9s3+7b/ftfsJP+AmoHAA0BAtaK2cAtrOiZxjbHfrpKU4Wnt7B0vy9v/f3zAMvvfSSLDnvsVW6sQFLjTnlyTba4fO4qOtkc+N/+A//QZ//M3/mz6CSeUPkF/yCX/BTfIpPIZXw1/7aX8uqmhlnUUmylDwN5cBnqeedYAmqa7H9/J//81vOaA8a+Z/+03/aOZIUUlht3EU5dohvbDiTERWdAbof+AN/YFavvPLKp//0n/6O5f4ADi+2NRuZMnywNJPHPA1BNf3QH/pDZZqhCyed07XHZ/7Mn/nEnHJ+Ty8n7SmfVmRJn+EzfAYXZtoBd66dfvWv/tXxmKM/ySf5JDDIYyg8vayYK3j8kLMyRhrAvsf3+B56crZ/8S/+RScmkzjLyXxRnl46NTEPOcMnT55Y/cNrHErBi0OD+4f/8B/+lJ/yUwzhF8Oijdx5GZUkvpbz9Hn/Il/ki5i44HXFVmgr8oQ5OgXmebnA5nSlM6Z842/8jb/f9/t+xpcxSFUKDK4InJHpJ4xnZZlmmFNYqnPncIoVx7LWGdGw/+yf/bN4uorWME5DMtu8PHSbVMJL2RadM3PWP/2n/5Txb/pNv8k19EwAbsGlPnMm/sJjXnR66zGn/3f+zt9peTDmR61n9VCkp6PC9BOYoypa2FmxBuDrB//gHwzwhb7QF1I07e9H/+gffS5Fcs0qL2eWKDuWvaXqMO94xztclXGRIWbCX/pLf+lrfI2vEX5ZkpkMhYHRwghpHBcvkswqP/2n/3RgtLoNwVTwy3/5LzfouLbeIPrpPt2ny1eEozq9BDhh5PCED/uwD7N8VUvGe+YnOJmyDjCT9KJ5ET2TJhNOwunnQuat1ZH8wT/4B8912ggzedTLSBIgAzuz3/Jbfsvf83t+j0FZjbly/h//438oHcHQ+WN+zI8xcI9/tgr44GXOLg4ejRrvX3755U68q7eVoaycbjox03ysvBggXX1mYqfuMkveF/yuMd0emSwzF8Ct3kjzvb7X91JxDH/JL/klBlQmrgLf9773KaMlovHvQiJ6luU835OXQ4OZjROrjkx+xs/4Gb/sl/0yI5noP/gH/0AFYitXp23KTC55Xq0uVz/iR/yICGWY8Ht/7+91VfZpP+2ntSbMVmiTwPrKxZuiMcxXqXmhrMVf+GGCabh1TteZKke5loHlPE3ReMY5L7MiPGo4wFynsWB2neNMfY7P8TnSZF720tx6ubgYucWeiyhjWZd/Vhk6DHMXAuYcCygVyDby0wWGBy/9SDt9DDqBgVqwP9NJcvpbPZ+AZWvCUm+LJGmpwycYhhWJyV/5K3/l0U2tZTWGe6aHJTirk+3i4oy6+WB36C/8hb/AxDaaqizVNiulVmgH79JtAsA7eCkbQsfJnKzzG7T++l//68DOiv75uT7X51IcixCjptO29W3NS1tkmC3yeyd3Xk7lUjP5tt/227aoADab6ULndhOwRmb5oRqbiF566aWyPc68yH/6hJzmguxQOfYwgV1bfu2v/bUHK1XUMU7KNIRIVpbwLwhXwAvG8sxcp+1pG8il3iLzMn2wwpPNWXjXu97V+t8OrcWzrY7WAnYdWn/Cz1ARKkUkD2W5FCmvhZf8fdbP+ln5g3eqLM/OIWd5Pf2d8qVI5eAMx2DjLy9GAgMbpyfslGcy5VmW29TBEjQpi3VebGlYSi3VZoD1oW6jfZwTHQBOpT697NwQHGf9/rAf9sNc+uOxHrPC6R
YhElsO7kGZE+YxIYa8CE8vlWUhIfmd73ynxgRp38nCTIMYFYGLX/gLf6F+6zqn3UglGoC700uE9y3kbsd80fAux+3
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x138>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2024-04-08 11:37:01 +02:00
"source": [
"# Training driver: runs one training pass per epoch and periodically samples a\n",
"# preview grid and writes a checkpoint.\n",
"\n",
"# Sample and checkpoint every SAMPLE_EVERY epochs.\n",
"SAMPLE_EVERY = 5\n",
"# Number of images generated for each preview grid.\n",
"NUM_SAMPLE_IMAGES = 32\n",
"\n",
"checkpoint_path = os.path.join(checkpoint_dir, \"ckpt.tar\")\n",
"last_epoch = total_epochs - 1  # range() below stops before total_epochs\n",
"\n",
"for epoch in range(1, total_epochs):\n",
"    # Free cached CUDA memory and run the garbage collector between epochs.\n",
"    torch.cuda.empty_cache()\n",
"    gc.collect()\n",
"\n",
"    # Algorithm 1: Training\n",
"    train_one_epoch(model, sd, dataloader, optimizer, scaler, loss_fn, epoch=epoch)\n",
"\n",
"    # Sample and checkpoint on the regular interval, and always after the final\n",
"    # epoch so the last epochs are not lost when the number of epochs is not a\n",
"    # multiple of SAMPLE_EVERY.\n",
"    if epoch % SAMPLE_EVERY == 0 or epoch == last_epoch:\n",
"        save_path = os.path.join(log_dir, f\"{epoch}{ext}\")\n",
"\n",
"        # Algorithm 2: Sampling\n",
"        reverse_diffusion(\n",
"            model,\n",
"            sd,\n",
"            timesteps=TrainingConfig.TIMESTEPS,\n",
"            num_images=NUM_SAMPLE_IMAGES,\n",
"            generate_video=generate_video,\n",
"            save_path=save_path,\n",
"            img_shape=TrainingConfig.IMG_SHAPE,\n",
"            device=BaseConfig.DEVICE,\n",
"        )\n",
"\n",
"        # The same file is overwritten each time, so only the latest state is kept.\n",
"        checkpoint_dict = {\n",
"            \"opt\": optimizer.state_dict(),\n",
"            \"scaler\": scaler.state_dict(),\n",
"            \"model\": model.state_dict(),\n",
"        }\n",
"        torch.save(checkpoint_dict, checkpoint_path)\n",
"        del checkpoint_dict"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 22,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:36:13.403286Z",
"start_time": "2023-02-23T07:36:10.026409Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:03:08.012329Z",
"iopub.status.busy": "2023-02-22T16:03:08.011929Z",
"iopub.status.idle": "2023-02-22T16:03:10.213564Z",
"shell.execute_reply": "2023-02-22T16:03:10.212573Z",
"shell.execute_reply.started": "2023-02-22T16:03:08.012295Z"
},
"hide_input": false
},
"outputs": [],
"source": [
"# Inference setup: rebuild the network, restore the trained weights, and create\n",
"# the diffusion helper and output directory used for sampling.\n",
"model = UNet(\n",
"    input_channels=TrainingConfig.IMG_SHAPE[0],\n",
"    output_channels=TrainingConfig.IMG_SHAPE[0],\n",
"    base_channels=ModelConfig.BASE_CH,\n",
"    base_channels_multiples=ModelConfig.BASE_CH_MULT,\n",
"    apply_attention=ModelConfig.APPLY_ATTENTION,\n",
"    dropout_rate=ModelConfig.DROPOUT_RATE,\n",
"    time_multiple=ModelConfig.TIME_EMB_MULT,\n",
")\n",
"\n",
"# NOTE: checkpoint_dir is not defined in this cell; it must already exist in the\n",
"# kernel. When running inference on a fresh kernel, set it to the saved run\n",
"# first, e.g.:\n",
"# checkpoint_dir = \"/kaggle/working/Logs_Checkpoints/checkpoints/version_0\"\n",
"\n",
"# Load on CPU first so restoring does not depend on the device the checkpoint\n",
"# was saved from; the model is moved to the target device afterwards.\n",
"checkpoint = torch.load(os.path.join(checkpoint_dir, \"ckpt.tar\"), map_location='cpu')\n",
"model.load_state_dict(checkpoint['model'])\n",
"del checkpoint  # only the 'model' entry is needed for sampling\n",
"\n",
"model.to(BaseConfig.DEVICE)\n",
"# Switch to evaluation mode so dropout is disabled while sampling.\n",
"model.eval()\n",
"\n",
"sd = SimpleDiffusion(\n",
"    num_diffusion_timesteps=TrainingConfig.TIMESTEPS,\n",
"    img_shape=TrainingConfig.IMG_SHAPE,\n",
"    device=BaseConfig.DEVICE,\n",
")\n",
"\n",
"log_dir = \"inference_results\"\n",
"os.makedirs(log_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 23,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-23T07:42:49.677019Z",
"start_time": "2023-02-23T07:41:04.890036Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [00:49<00:00, 20.03it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAESARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPq2laVrP2exg8ny4vssL7d0KMeWQk8knrXIp8afiNIwVNeZmPQLY25P8A6Lpvxt/5K9rv/bv/AOk8dR+CPiDZ+C9Dv44dCt7rWZpN0F7OAwiGAMYxnqCeD3oA3L/4g/GLS9Oj1C+uru3tJACk0mnQBSD/ANs6x/8AhdvxD/6GH/ySt/8A43Wl4T+KniHUfFKWXiC7XUdM1RxBc280SlAG4yoxxj8q4HxHpR0LxLqelEki0upIVLdSFYgH8Rg0AeueIvil4zsfhx4L1W21nZfal9u+1y/ZYT5nlzBU4KYGAccAZ71x/wDwu34h/wDQw/8Aklb/APxujxb/AMkh+HX/AHEv/Sha5PQfDuq+JdRWx0izkuZ26hRwo9SewoA9A0L4l/FfxJetaaVqz3EiIZJMWVuAiDqzHy+BWdL8aPiLDM8TeIQWRipxZ25GQe37uvR5vhz4g8K/DGXSPDkSXGtai27U545gGWMDiNO569eO/rXz5eWdzp93LaXcDwXETbZI5BhlPoRQB7T4d+KXjK++HHjTVbnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/wDC7fiH/wBDD/5JW/8A8bo8Jf8AJIfiL/3Df/Shq5zwj4T1Hxlr8Ok6coDt80krfdiQdWP+eTQB1dp8X/ibfyiK01iWeQ/wx2EDH/0XUU3xn+JFvM8M2utHKhwyNY24IP8A37rp/EHja0+F0b+FPBMcD3CL/p2pSDe7SdwO3H6Vn6PY6T8VtKuLJBbaf4xMxuWnkLEXo2nIHPy9s0Aa3wt+KXjLxH8RtK0nVdZ+0WM/nebF9lhTdthdhyqAjkA9e1cf/wALt+If/Qw/+SVv/wDG6l+DlvLZ/GjSLaZdssT3KOvoRBIDXndAHoA+NnxEYgDxBknsLK3/APjdXbb4qfFa9fZbajcyt6LpsJ/9p1jfDLxJoPhfxT9u1/Thd2/llUbZvMTf3gp4PpzXpkfxt8R+IfEMWl+DPD1s0ZO1FnVmYqO52kBRigDhLn4x/EuyuZLa61t4Z42Kuj2MAKke3l11XxS+KPjLw58RtW0rSdZ+z2MHk+XF9lhfbuhRjyyEnlieTUf7Qklj9s0WOWKL+3fI3XckPClewI69eme1cn8bf+Sva7/27/8Ao
iOgA/4Xb8Q/+hh/8krf/wCN1taN4/8AjH4hgefSLq6vYkYozw6fbkBsA4+56EVgfDT4dXPjrVGeZmt9Htjm6uBx77VPrj8q7Lxv8Rrfw3BH4Z8CW/2TTrOVGe+gPEzAAld2OfQknnFAHMXXxj+JdlcyW11rbwzxMVeN7GAFSOoI8uuq8RfFLxlYfDjwXq1trOy+1L7d9rl+ywnzPLmCpwUwMA44Az3rzPxv4tl8a+In1ea0itmaNY9kffA6k9zW54t/5JD8Ov8AuJf+lC0Aa/hz4k/FTxTqbafpmvRtcCJ5sSWtuo2qMnny+tZUnxp+I0Ujxvr+HQlWH2K34I/7Z1e+G0C6L4M8WeLZoUZorcWVozf334bHvgivL+pyetAHoH/C7fiH/wBDD/5JW/8A8brsPDvxS8ZX3w58aarc6zvvtN+w/ZJfssI8vzJir8BMHIGOQcdq8ss/CGtX3hy+16KzYafZbfNlfjOTj5fWuj8Jf8kh+Iv/AHDf/R7UAH/C7fiH/wBDD/5JW/8A8bo/4Xb8Q/8AoYf/ACSt/wD43Xn9bfhHQX8TeLNN0dQ225mAkK9Qg5Y/98g0AdS/xm+JMaK7646q33WawgAP0/d11nwt+KXjLxH8R9J0nVdZ+0WM/nebF9lhTdthdhyqAjkA9ah+PviCxe40vwrp5Qppq7ptq42NjCqO3TOfwrlfgl/yV7Qv+3j/ANJ5KAD/AIXb8Q/+hh/8krf/AON1v+G/HPxi8WT+XpGoySoDh5msrdY092by+K8grpIvHfiK38LReHLTUHtdOjZmK242M5JydzDk0AeheKvGHxh8HLbyarrUXk3GRHNBbW8iEjtkR9a5n/hdvxD/AOhh/wDJK3/+N1r6sX039nvSra8cySX980tsrLgxICSfwP8AWvKKAPv+iiigD5A+Nv8AyV7Xf+3f/wBJ46d4P8BaXfeGbvxV4q1CfT9FhfyovIUGSd/9nIP0/Om/G3/kr2u/9u//AKTx1J4c+Iuk2nhO38OeJPDa6vZWkjSW224aIgsSTux160Ad34M+GPg3W9TtNf8ADus3k1lZXKGaG8jA3Ec4BAHtXlXxIN1L8Q9fluYJIib2QDeuMgHAP0IAP41o+Kfibf63YW+kaTbro2jW4Gy0tm6kE8lup61zeseJ9a18RjVNRnuRGoVQ54wBgZA6n3PNAHU+Lf8AkkPw6/7if/pQtc3onivX9Bsruy0bUJrWK7x5whUBmx0+bGR1PQ966Txb/wAkh+HX/cS/9HrVDwZ4/fwZb3C2+haVezTEHz7yIuyj0HPSgA8MW3jnW9Xgm0aXVZpklUGfzX2oSf4iT04Oa6n4+LZJ4n0xVdW1UWS/bygGC3QEkdT1/ACszUvjb4uu7b7NYPZ6Tb4x5djAFA+hOSPwrz25uZ7y5kuLmZ5p5DueSRizMfUk0Ad34S/5JD8Rf+4b/wClDVu/CgXNh4C8cazpwf8AtCC2WKNgwGxCCWbp1AGfwrC8Jf8AJIfiL/3Df/R7VX+G/wAQW8D311Fc2a32k3ybLq2b+LggEZ474II6E0AYnhzwtrfjPUpbbSrc3M6oZZGZsAD1J9zXf+FPhvqXhjX7TW9a13TNJGnTJNLG1xmXaDypAwRkZH41cl+KXhTwzompQeBdGnsdQvsE3EpyEPfAOenOBXjt1dT3tzJc3UzzTyNueSRiWY+5oA9s0LVdC1r9pjTr/wAPrizlWYuwUqHk+zy7mA968d0fRNT1+/Wy0qymu7lgTsiXJAHUmuw+CX/JXtC/7eP/AEnkqv8ADn4i3Hw/vLqSOwiu4rpQrqx2sMHs1AHpnhT9niBUhufE168j8MbS3O1fozdfyxXWa74m8EfCTS3sdNtII75l2rbW43SnPIMjHJx9Sa8f8V/HLxP4hiNtYuNJtmGGFsf3jf8AA+o/DFeaXFzPdztPczSTTOctJIxZmPuTyaALetaze6/q1xqWoTtLcTuWJJzj2HoBXYfG3/kr2u/9u/8A6
Tx1wFd/8bf+Sva7/wBu/wD6IjoA7j4E3drqfhPxH4WnuHtjK32gyRkbtjKFbGc9Nv61T1L4n6a2vW/hXRdBsr7wwA
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAAESCAIAAAC+Vc10AADkC0lEQVR4Aez9Z6C2WZrWdQNiwPiac86KOWAYtarbnAM6OioqDCqCCAqIYDvV1SrKIIwSBFTAhIABxZy6u6rNOWHOOWcxy/vb+7/rqFXr3s/TNeA36/qw7nOd6zjDyuEK93f8+T//53+Hz67PSuCzEvj2lMB3+vaAP8N+VgKflcBzCZht3j7hfKfv9NK1vuNHFznkL/KL/CJvL0KYAIg3WRnmLdqGGfFot6TTSm7HX4jA3/VoNGQZnJUxCWJmZcxgJaHpzEr8YAsnhegaTJRsetBnXsJcIcxkEYtedAoLH/M7K7Mb8SZb44OdyKyfPpxuzArmpC5wmktd0ohTW8iFw5xWHlPjDDwAIqYw4i1ISS9WHo1xPflTC4FPw09kRTPi0crcndq3uFt9Uz6FE0+q8LQy5OnSI52ewIVhpnYiUqNn5URespJKXYg4W234WcmNMzorJzP6SfVznynEfNQ2JOK0u5IMcFl5VvxSyFMOeZp7k60ULjzFZ2XMi6BznJl7lTiZ6NOZWcF/vE79V6qkLnzElXpGpb5YeTSWZOHp1kmn67JR9Dt/5+/8mPpoZRhqu07/3kRPeYCzQeDMyqNjj9m5MK8qPN2ATyQraG4PUFLR8WMWvprHkKdsGnCWl+XxUjvkZfQtOkMK59KsXNqGfBPx6POFnLf4K7ELU/TtqtIjHGwEpsIRjXPlZbZKXfTrEsNfxGnlqXooisWJ//v//r+Lxp8NgIuzpAh5+H/+n/8nGti1KOasnLBLwxklPnO/3q/36/2Wv+Vv+Z/9Z//ZP/fP/XMwS4oQTv+snLBT7aehLw+nf/ysKCsZxBQNkxtF8+30Z3TE5Un4MSn/v/6v/0sU3xVxlmecU9Xcm5KLSM8pgpPO8iJ6pib+zjvv/AF/wB/wT/6T/+Tf+Xf+nf/n//l/XjqLXoKL5tKsIMLXzJb6aPS0krbpXNIjJz2zMuSrhPH3d/vdfrc/+A/+g5Xz//F//B///X//33/44YdXA3tVMCsvM0OI2oGwtLJX0vL2q/6qv+rv8rv8Lr/Wr/Vr/Xf/3X/3H//H//G/8W/8G//Ff/FfwFQBlQXw8JIwUzLYqXlJiLMsSFHyp/1pf9rv8Dv8Dr/ir/gr/k1/099UrjBLyoTw1ZKKvzAr05+fMX/NX/PX/GV/2V/2f/qf/qf/5r/5b/7X//V/PTMSPv34EyeYZszyMmSaS42JzlBSwl/8F//Ff7lf7pf7JX6JX+L/93ypsDCn/on8or/oL6peTyVLQqwkcxvn9//9f//v8l2+yy/zy/wyTPzv//v//u/8O//OP/wP/8P4rypPVcpPEwP/4X/4H/79v//3/wf/wX/wP/wP/8N/5p/5Z+DLIIAro6cgQNGV8JlKRNTQfMnGJ3sW5tt9O9VC0hA++vSBJ5grn1/lV/lVfuvf+rfWhn/P3/P3/H1+n98nqf/hf/gffvbP/tl/+V/+l/8L/8K/cMlOEFJdhP9EtyGQjDR+lL04Gq5q+LV/7V/7c5/73Dd+4zf+pr/pb6qdaccGob/tb/vb/r1/799L3WyUDbKIMVOLSbPMIFwhg4mm5znl5xvq/uK/+C/+5X/5X16W/t1/998tSTiF1c2kApym04OPecJo0FV+g9/gN/jVfrVfrVHgv/wv/8u//q//6/+1f+1fSwlAxT0N8WOi8TOUM0VPE6M1YmONTsIWsCiCaUy95tf4NX6Nv+wv+8v+sX/sH/uv/qv/KhPCdBbVZ67GVF7Sf60OdMj33ntPNWkcLWj/2X/2n/3z//w/33RhQCVSiRXmf3mJMwfmvBnm5/28n/cr/Uq/koGybpNv+UAKcuDEY55ZmNqksjjmolKv7EhSTb/Zb/ab/cq/8q/8P/6P/+N/8B/8B3rvZS49J3M0H658/fa//W9vhvlj/pg/xvqFoHH/3//3//1f7Bf7xX7b3/a3xf
xv/9v/9t/+t//t//l//p8l5f+Vi4/nWzZmBrqiZCxvhL/0L/1L/3a/3W/3Z//Zf/bP/Jk/U/cAJvyf/Cf/SdVgeP5xP+7HqfsE0zDxmMInGx91iWl+JKZEkqr6p/6pf0o5mtBkKfAAI/DRGhbi0UqwgRHRuuKf9Cf9ST/35/5cc3RS/8v/8r980zd9EyUK0QCRuTNkQr6Er1pJbfiTZuhP/pP/5J/6U3/q3/V3/V3an85pYEvDGX7rt37rr/Pr/Dqnua9rJfCKWlQ/lCndj6xBzbKWLRVvwlGAv9Qv9Uud+qMvK3kujID5c//cP/c//8//8x/1o36UcfNRPFiF/5gaByYriLydctHRl/jyZYw2ltGgJSilYFkcJuaZlzepNTIGs7L4R//Rf/TP+XP+HNOy1vvlL38ZX3szu77qCYXpfLHydY1RrZeDqQAdRvP6yle+8sf9cX/cl770pX/+n//nWz/8zX/z33wau4rjbGrZPsGv0ur4+3yf72ONobDYeovUkhgtL3HGf9Rvqv1T/pQ/xbgCT///9r/9b//EP/FP/L1/79/7Pb/n9/xtfpvfhrnv9b2+F8ypoRqK83WtgIU0mSi98PqnAixEdCk9DgRQnr/kL/lL5i3xWTkbR2oXRiRilPm2b/s2o1iCP+tn/Sx5/NE/+kdXhn/L3/K36MCPRRH4kR/nN/qNfqN//V//12G++Zu/+U2Yi3+6tO6UlWUEZnTioq4zX8P8nJ/zcxLX0P/Kv/KvPM0ROaPBTgfOVPRv+Bv+hl/96lfBDMQ/4Af8gDP1i1/8omlAOzc9nPxHOiuvjNAX1PqhOraZ+ek//adb0lQB/Pt1f91f1zhK0T/9T//TptEJLs/Lw4uxT+5zwsMMtoKwn2k205hOzURWvjM3IivThh99cowu3+/7fT8FF9hI/EN/6A/9nX/n39mcoFaMN7JpHPrVf/Vffc4kfjWCGR0B7wIWYsL/4B/8g1lRekIdRg9Ba8eqh6F/5V/5V6wHtIaGHqvqP+qP+qPSRkPupSrmmYsMzbTF+j/yj/wjZjPjGkEjwu/3+/1+Un+v3+v3sqphVP9ZrU0PIivTcxIGjr/oL/qLeKgiTGJnUvT0PCaNA+M6rYhKLYy4OGfW5EvzS9wW+of/8B8+zREn+LRSKs0V4MypeqOJQ6bpKekP+UP+EHOa0vtL/pK/ZEmIUk8lL1b6WZsAPUEaWX5bBf6Zf+afacF2KkXrwX/VX/VXydJf8Bf8BVdS0bQtS8vAaXHMRCw0v/a1r5kESP1Ov9PvNLWDpXP8EbOCk34irtPWj/yRP9KqA/Jf/pf/ZSVoaSSPFmYue8S/5+/5eyTpS6ad1BKffoQNw2nlWf3HJgLHBLY3++CDD+DrOZqgreAf+8f+sQrNiGMH8uv/+r/+n/qn/qlWbjD61Re+8IXZOq2M+SrxR/wRf4Q5oT7JkCHMFKfbAzP0l/6lf+mP//E/3lbnks3V08qcD6lA/oa/4W/glaWEVfol/moVnOU8ALWzspJJ2zCL5sNsGZSNYolrEmpK0jAR0zkrJyYTE7H3c03/CDt2SycNwxpnzEnFKfpiZcYuECgOG3ZRf+Kf+Cf+QX/QH/Qr/Aq/wjSWWtSWgD2j3Zl6lcishLluv2AOb+vpQEP1qzDVv7OL3DsrJlWn21kZZ8SkbJftysD+rX/r3/oe3+N7XCVos/4N3/ANjuwADBbqzMpnmZqHb7LCHEPCLoL2AxZ+OqHs/EP/0D9kmf4b/8a/8WX0r/gr/gqbEDo1dx1gmTqtUDg3LoKIFWZgPUfRmTk1r7KsnC3ct2uaY1OY4KWzKOcd+QD8sB/2w+p1k3oV/yYmqctKelYpCvZRs1Q9XzUlqwD/xr/xb2QCsmvmFg0puqSvSwQ2+fy1f+1faxVgJF0VXLJT+5KXV7N0yciY62Quz5iGIkechkyHHifmpGcl85e2kFb2Vila7X/6n/6n8JYulhZmVUcf1zpteRiRhlmh/0rSgBzK1W
cUkDmnDpknUufS7/F7/B41F9r+9r/9b/+tfqvfKuVTOCvLoKQunOnBUUpOt933sAK0uB2+0nO29if8CX+ClmGAoNP
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-123303.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=64,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=8,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 24,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T13:56:16.260979Z",
"start_time": "2023-02-13T13:50:06.139878Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:03:53.237665Z",
"iopub.status.busy": "2023-02-22T16:03:53.237279Z",
"iopub.status.idle": "2023-02-22T16:08:06.306724Z",
"shell.execute_reply": "2023-02-22T16:08:06.304477Z",
"shell.execute_reply.started": "2023-02-22T16:03:53.237633Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [02:56<00:00, 5.66it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAESBEIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KPxR8ZeHPiPqulaTrP2exg8ny4vssL7d0KMeWQk8knrXH/8Lt+If/Qw/wDklb//ABuj42/8le13/t3/APREdc14T8N3XizxJaaRacNM2XY9EQfeb8BQB2ulfFH4r65cfZ9L1G4u5cZ2w6fA2P8AyHVa9+MHxM068ltLzWnguIm2yRvY24Kn0P7utPxN4/t/D2tQeGvDSzWvh/TJQlybSXypr2RfvM0gGcZGPfntjHm+v6zceIdfvtWugomu5TIwAAA9B+AxQB694i+KPjKx+HPgvVrbWNl7qX277XL9mhPmeXMFTgpgYBxwBnvXH/8AC7fiH/0MP/klb/8Axujxb/ySH4df9xL/ANHrV/4YfDqw8QafeeJfEFyYdGsGOY0ODMyjcwz2AGPc5oAqR/Gf4kTZ8rXHfHXbYQH/ANp1H/wuz4hg/wDIw/8Aklb/APxuvYPBnjG31nV5dM0jwHLYaZKpjF9DEsbKmMBm+Xv9TVLx18F9Cg8D3U+iW7rqVmGnMzyFmmHVg3bpyMDt7mgDnPDvxS8ZX3w48aatc6xvvtN+w/ZJfssI8vzJir8BMHIGOQcdq4//AIXb8Q/+hh/8krf/AON0eEv+SQ/EX/uG/wDpQ1cLaW0t7eQWsABmnkWNATjLMcD9TQB3X/C7fiH/ANDD/wCSVv8A/G6P+F2/EP8A6GH/AMkrf/43XqPgzwv4F0HxAPC0+my6zrxVTdTTWwkghO3JGTwMfT0rxv4labpWkeP9VsdGyLWKXBjIwI27qPYUAeh/C34peMvEfxH0rSdW1j7RYz+d5kX2WFN22F2HKoCOQDwa4/8A4Xb8Q/8AoYf/ACSt/wD43R8Ev+SvaH/28f8ApPJXn9AHoH/C7fiH/wBDD/5JW/8A8bo/4Xb8Q/8AoYf/ACSt/wD43WB4Q8LTeKtWNv5q21lAhmvLt/uQRDqx9/Qetdk+jQa54R1SXw1pkVlotkreXPcwLNd6g68nDEZQAcnbgDpQBmf8Lt+If/Qw/wDklb//ABuuw+KPxS8ZeHPiPq2k6TrH2eyg8nyovssL7d0KMeWQk8sT1rw+vQPjb/yV7Xf+3f8A9ER0AKvxr+IjMFXxASTwALK35/8AIddGvi344tpv9og3/wBk2eZ5v9nQY2+v3M1z/wAFr
M3fjxikEM08NlNLAsqgjzABtIz3zXqXhvwx4+s9dtdf8SeJYyJgwuLCSU4II4UKMDP0HFAHlP8Awu34h/8AQw/+SVv/APG67DxF8UvGVj8OPBerW2s7L7Uvt32uX7LCfM8uYKnBTAwDjgDPevGtVgmtdXvILiLypo53V4/7pBORXa+Lf+SQ/Dr/ALiX/o9aAD/hdvxD/wChh/8AJK3/APjdX9K+KXxX1u5FvpepXF3MRnbFp8B/9p1wvhvRZvEniOw0e3wJLqUJknGB1J/IGvWPEvxKtvAITwt4CiijWyOy5vpEDmWQdcZ6855P4dKAOXufjH8SrO5ktrnXHimjYq8b2NuCpHYjy66vw78UfGV98OPGmrXOs777TfsP2SX7LCPL8yYq/ATByBjkHHasz4zaRPJpnhvxRdJGt9f2oS8K/LufAKnb9CfyFYXhL/kkPxF/7hv/AKUNQAf8Lt+If/Qw/wDklb//ABuj/hdnxD/6GH/ySt//AI3XAojSSLGilnYhVA6kmvaPDPwm0rQdF/4Sb4h3H2a2QbksM4Zj2DY5JP8AdFAFCy8ffGXUdKk1SzuLyaxjUs066dBtAAyTny61fhd8UfGXiP4j6VpOrax9osp/O8yL7LCm7bC7DlUBHIB611EfxLhv/hh4l1Oz0+Gz021jFjY2yHDjeCuTjgdQQB6GvKPgmSfi/oZPJP2j/wBESUAJ/wALt+If/Qw/+SVv/wDG6P8AhdvxD/6GH/ySt/8A43XAUlAHpsfxX+J8mjy6qNcxZRSiEyGztwC552j93ycc1S/4Xb8Q/wDoYf8AySt//jdZ3iKUweBPClki7UeOe5fn7zGQgH8q5CgD3D4pfFHxl4c+I+raTpOs/Z7GDyfLi+ywvt3Qox5ZCTySeTXIL8a/iK7BV18sx6AWVuSf/IdJ8bf+Sva7/wBu/wD6Tx1V8CeNbXwfb30kfh621HVpCptbicbhCB1+Xr+IIoA6jTfHHxq1eGSawkv544zhmXTIOD/37rGufjJ8SrO4e3udbeKZDh0exgBU/Ty6qap8VfG+sahG0ur3FviVWSC2/dKDngcckfUmtn47RH/hK9KupohHeXWlQy3SgdJMsCMfhigDd8RfFLxlY/DjwXq1trOy+1L7d9rl+ywnzPLmCpwUwMA44Az3rj/+F2/EP/oYf/JK3/8AjdHi3/kkPw6/7iX/AKPWuChhe4njhiXdJIwRR6knAoA73/hdvxD/AOhh/wDJK3/+N0v/AAuv4i7d39vnb6/YrfH/AKLr0S60PwH8NtF0211rw9catqdzD5k7+WW25GTz0GDwO/Wud1b4y6F/ZH9naL4G02JFOB9sjSRAP90Ac/U0AXfDvxS8ZX/w48aatc6z5l9pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/wDC7fiH/wBDD/5JW/8A8bqfw9c/bPhb8S7kwxQ+a+mt5cS7UXNw3AHYV57b2813cxW1vE0s0rBEjQZLMTgAUAd3/wALt+If/Qw/+SVv/wDG6P8AhdvxD/6GH/ySt/8A43V2z+B/iZrdbrVriw0m1Iyz3M3zKO/y/T3rrNY+HPgTw78KdQ1hbg6reqoWG9EzKDKSAAiggEc55zQBW+FvxS8ZeI/iPpWk6rrP2ixn87zIvssKbtsLsOVQEcqD1rj/APhdvxD/AOhh/wDJK3/+N0fBL/kr2hf9vH/pPJWH4X0XRL8td6/rken2MLgPFGN1xKP9hemPc9KANz/hdvxD/wChh/8AJK3/APjdH/C7fiH/ANDD/wCSVv8A/G6qxeHPDV9fS30/iOx0zTZHLx20W6aZE7KQTwcepNa2neEvAus6drZ0zVtZmudPsZLtJJYkji+UZG4YJwT70AVP+F2fEP8A6GH/AMkrf/43XX/FL4peMvDnxH1bSdK1j7PYweT5cX2WF9u6FGPLISeSTya8Qrv/AI2/8le1z/t3/wDREdAB/wALt+If/Qw/+SVv/wDG67rwf4g+J/irw/qGt
y+MYdNsbQHEtxp8BWQgZODsGO3NeS+C/C9z4w8UWmkW4O123TOP4Ix9412vxS8Ym+mTwX4diePRdKPkFYgSZ3Xgk4
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEIAAAESCAIAAAB2Fud+AAEAAElEQVR4Aez9d/yuS5bXdYtiQFBRQDIIkhVBcmhhnxmSIJKjIElAhpxBaOk+TYZhQHIOApJzlGG6z2kEUSQJkiWrgygiPoxiet57f/b5dnXdv72nkecPn9frXH/UXrVqpVpVtWpVXdf925/h//6//+9/6N3nXQ+864F3PfCuB971wLseeNcD73rgXQ+864F3PfD/Px74h///x9R3LX3XA+964F0PvOuBdz3wrgfe9cC7HnjXA+964F0PvPCAtzGPL2Q+w2f4DBor/+F/+CNHnUe81n/kH/lHXuNLBLimJQkJB++51KVU6ym56oWcdvhpwZ6Ei/1CquLqiRJ8AiOATBGCaYG86MOcJoU59YLP6iSgPJ+01Dr6EQ/A8io4aSdv8MxDcGoJj+YUmJCTJUw0Ex6yMmRyIkvLMOOdogGnnBO+FCUqArxVz3G56FGeKmqF2ROBajLP/p664E8tEVfijXJCkjmaE381YZzzNameWk5jJu31wGt0nYwbl5DjGsASTaqeejf2miqHjHhVXPMYGH4CT5pTOAIPTM/IQqqe+BM+PRbXP/qP/qMBr3cgIdPyeuAzfsbPmBbAKBkGriRqdkbweuFaEwXwJAeQljBTNGC6hjmBV3GdNGBC5rEEZrmSu04hJ3wJWVMStA4Dzu1pGUESIlNeuq7Wqk+WE0iIZ30Bn/QjO5Gj+fudGKcWkk85l6JVzw6OfsYgGyUggmkZ2T84kOTpInBapjctp5GPXI+WnDIfW6flsWmYvy+NJ/EknFqeJED5aOcwsYwR0COAAJokgFPLVMNP+AmPIICuHtXIRrBqwJNaEGOP4Cwn5O8L0KNpmXYSTrhqmHngIriU1qpkaixpGT76CUF2armkPVmNN+EjWF+GeT1ASHJeT6Y1CyNLy4kZnLSzHO/HqGiWnH35dHkZkCsAJzH4wkx+ZKcWTeOtR6uuaYomZ00XBm/sykvLyRLNeK/q8AGvb32p5VHZqU8HbLQJAldFMPdN5WhO9rWmJRq8ywNgVJU9qUg+TOyQYSYtIMYTmZYJIWEwsgSuPJtOIVFGluqr9UmPnTTxhpm6VxGM7KK8tCwni/41xp+KXgWvX5eWGTNGVo14SJizj48EpxytaYnlcdQmdgDKHhgAlprSC6M6pbVCri9rmsDXAEmb/JMyOZMW5bSclE/C6MeLIPapA2T5uRaSo+nScsoZzWSGmR+qKi+uq4rgHBfVbDu9/aSodaHWqwp5KUpLSMSbyScjOL3jBYxgJs1ITcWl6JWXx15lW/iPvTxV41pfsi3t4D1oQk4FCbUOr7rWgAszLWO56B+rl4RHgkdFabkYVV+j9FWtCblGFlLfNy6qnkfD8vDl55Ps5GLbWX0u8cUzLSfjYCSD/0GAaUlgZSaBZ1vAejQCqmcJpAem0mQOgEnLqqfBYw9Z9UKe9K+Bn9RC1MLRI+/rFWl9tHkee5T2Gswl5/V6yXmVlhg3EGmEDK8aXDWlazrNiywtI3gEsDx670kyJiWz8tT1qr6cNE/CX/yLf/HP83k+D8mv9172XFpmZJJn1YkHn548m8YFON14eewkC34sT/a1nrqCT7KrL5eWk1cTRpjYn5Q/JODknZaQZ9MlbYwn5Ul/qgA3GUI+egwj+Z5HCZCXKNWRvWB6yTXkqeXiTdpFOZoZOYJL+/CAnnlsQi4A2SlE9SJQfRJ5kr3UMmVjOEWPodarVD1n9kVcNWnTAjlFA9AMJjCWMMoAjJdhZzX40jJGQAQnSwJPzAlnfOXk1NlTS5gRjD5Rwz9KfrJpSMCphdg1nYYFPzkEHyPZpQUXU+lKXWYnP/w6MoKAS91z/mNSpgVvD+JaRwN/qVZ97Fdkz0UfwpOmaX25COJCFuVVvkB/VJHwk+tsnpbkXGSXYWNEdlHWhH7JXwLDz2OnhEeHnCyjfD0w5wDWlyfNvpDnDn
01TaM+rmn9vbSMANdzp7zDcvYOMpkhsTz5RIOYD6dlxmB5Ek74Wj//5//8733ve3/zb/7Nv+bX/Jrv/J2/M5ZpH/uAtIwXZbAyIMqqZ4mSA4cZGSBkgBJlWl5jRuxneXrvxE94SDITC395DB5yvOAoK4efnIjHwgBwDxrV+ru+aIIfQXJO5AxbU61V4x0cQEtcV18uso+lOlGvIU4Lq+aTurO+w4NJyHJVT3B4Va0njWo0IVVf05dEoXn989k/+2f/Ht/je/ze3/t7f+7P/bn/1r/1b/3L//K/fNIn5NJyST6r68KjkBMDPrlqurTo4yPNKST/nJgLfpL91DIJKEccMkzIYMI1nVatdXrDqKYFfZgp0jSauF7TFAH6nqrKRh9jWrSu6QIuRq0wn/Nzfs7f+lt/6x/7Y3/sa32trwWTATPjkmZhzmM1oQREP+AScplxVqdoyIRMy/ABUzr8o4Q1DYhLFXDC0zLkWCKu+oLpZQg6CT5GOC2zc7oCVh3BKXatA9Z67qqQ60sE6BN4iYX3oKls8gQPOQkBJ/7SMoKAS9fVelWn9MKrvl7LI/2rMOz5dLW8/DoCHa0JAuD8v/6v/6sqeAZFVgn/f754HtWjn8CknTQwqggm/9QFjmAlSroQaxqgGkFClBMYZq1ZonoRICPQDGgS6MqEjPcSFU3INSWfYesF9iRUItAUWVwXHHISztZT16vgR6uepFwHs+qkodGT/esFgihhZhvAA7+m5PAhzMWb2NyLbK3wsVeeTVN6AuCeJIwrJHvgTyQ41YBao7z0xjLkZ/pMn+mLfJEv8k/8E//EH/pDf2hc6+wwL015Md+mlJBHYyZZ0/o+9gAS/vf//X8Pjn6iTq7MIGRmBIxllozg1H4qHeWTyGkfmeEzc+D/j//j/zhZgqNHnLq4QtblYMQRAFiYwCRgCTNewNwVJd5Ji6zoX2uYpJ3liT/hRMFk0tf4Gl/jm32zb/ZlvsyXwfvP/DP/zC/6Rb9o2idtXg0TI8+gJGcAMgQwnmiUnsiUYE3IVsLU/RiV0z4AMkbA43NK+NJf+ktLW01jp7L/8D/8DxFTdLKoerLzxIOpyzACVZUZPDPCY68pmoTAGAswCXGpnlrGgoaEeJ+b8gI+ywRGkGq8J0FGwmxmnixRDvMI/GP/2D/2uT/35/4CX+ALAD70oQ/9c//cP+cM/N/8N//Naf8j12zQNJglYIyes+PRjIxJWtHA1yOlCWxBoQkZAH7Vk7QnexevpmhI+Jpf82v+8B/+w53PP+7jPk6C+0//0//0p37qp/73//1/n/DsmaIsPyWHGUE2rwqYohMJPm24mqrmBPClYsSPutYUcNp5NbEqsQAPUcqzdzNvMwcBZEpPk16jJeLkM2Dyk5lJqQY/ykljeLxoYJTn9DgZaz3lgEMi64H5gT/wB379r//1CTG3ITOyUjWWCTkjeUiUdSTzEjt21eHtkt/gG3yDr/yVv/L/+r/+r3/qT/0pBydlQkaDnrRXjfXITvmzLdVnmWHKUwsh4U/GE56EIU/gy37ZL/ttvs23gWGDdfFH/sgf+YN/8A9+2qd92rheBZxC0GQDAJ5JAcr6eJaQnrFXVRYELvzJWNPpq3jDKxHn6gkZMI0JhJ/BM0ATWDkVAy6aU2wCJ3+UTwIjTtfkDD8umO///b//V/yKX/Fv/a2/9Z/+p//pW2+99Zf/8l+ePY/0Y3x5jCEaUQ+nALbUJwUPMvjsyHdf6sUjUP7tv/23/8v/8r/8s3/2zyY6gcqAkMQClGQOn3z+HT7ilS9kvNyMI0Z5IlHCkxBLrcqqKNszxlKTKgK9sI21sGdJTZM2OYCJPZELYamuaWVIMtc6+WzOjYjTPq7XAxM1Mrvys2fP/sV/8V/8p/6pf0rr3/t7f+/P/bk/Z33+9b/+16OBHPEFaLq0w5zItc7aL/ElvsQX+2Jf7HN8js9By3/yn/wn4fODMvq6vDSdUjIhPS0n1TwQHoGmwTMS5gt+wS
/45b7cl/tsn+2zGaz/5X/5X37Lb/kt/9P/9D8lcGQDyHyh5LmoLEkmZDQRgCML6R7r4z/+44Wz//w//8+jPLnmgYj
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1090x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-123354.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=32,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 25,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:00:18.151514Z",
"start_time": "2023-02-13T13:57:07.698420Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:16:59.757373Z",
"iopub.status.busy": "2023-02-22T16:16:59.756510Z",
"iopub.status.idle": "2023-02-22T16:21:11.518330Z",
"shell.execute_reply": "2023-02-22T16:21:11.517233Z",
"shell.execute_reply.started": "2023-02-22T16:16:59.757332Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [02:55<00:00, 5.68it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIiAiIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPquk6TrH2exg8ny4vssL7d0KMeWQk8knrXIL8a/iK7hV18szHAAsrfJP/AH7pPjb/AMle13/t3/8AREdcHC8sc8bwFllVgUKdQe2KAPYJvFnxxt7Fr2b7elsq7jIdNgwB6/6uub/4Xb8Q/wDoYf8AySt//jddtpGqav4H8J3niXxfrF/carqUTR6fpk90zAqw/wBYyE4HX8vrXhRJZiT1PNAHt/iL4peM7D4c+C9VttY2X2pfbvtcv2WE+Z5cwVOCmBgHHAGe9cf/AMLt+If/AEMP/klb/wDxujxb/wAkh+HX/cT/APR61z3gzw6fFPiux0kuY4ZWLTSD+CNRlj+QoA7OD4m/Fm50a41iHUp3063YLLcCwg2qT7+XWb/wu34h/wDQw/8Aklb/APxutbVPiq1r4p+wafAv/CJWsb2Q09FVVmjI2lzx94kAg+3ua8tnMRuJTACIi52A9QueM/hQB7Z4d+KXjK/+HPjTVbnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/wDC7fiH/wBDD/5JW/8A8bo8Jf8AJIfiL/3Df/Shq4CgD0D/AIXb8Q/+hh/8krf/AON1M3xh+JqR+Y2sTLH/AHjp8AH5+XW58J/hRDrUUfiPxIBHpK/PBC7bRPju3+z/ADrutF8e2vjvxzq/g5hanw+9nJDbAR/M7DAyOew3EfSgDkvhd8UfGXiP4j6VpOq6z9osZ/O8yL7LCm7bC7DlUBHIB69q4/8A4Xb8Q/8AoYf/ACSt/wD43Vn4R2osvjhplqHVxDLdIGU5BxBLXn+nXh07Ura9EMM5t5VkEUy7kfBzhh3B7igDu4/jL8SZjiLW5HP+zYQH/wBp0S/GX4kwECXW5I89N9hAP/addb4Z8b+M/FD3cvhnS9E0W0tIw11Pb2i/KvfGfvH2q7puo6t420/VtI8e6ciWy20k1rq01r5LxFOc8fTOKAOB/wCF2/EP/oYf/JK3/wDjddh8Uvil4y8OfEbVdJ0nWfs9lB5PlxfZYX27oUY8shJ5JPXvXiLhVdgrblBwD6133xt/5K9rv/bv/wCiI6AD/hdvxD/6GH/ySt//AI3Vm8+LvxOsPI+16y8PnxCaLfY243oejD930rzau7+JNtGsXha9gbfBc6Jb4
YdA65Vh+Y/WgCX/AIXb8Q/+hh/8krf/AON12HiL4peMrD4c+C9VttY2X2pfbvtcv2WE+Z5cwVOCmBgHHAGe9eH16B4t/wCSQ/Dr/uJf+j1oAP8AhdvxD/6GH/ySt/8A43R/wu34h/8AQw/+SVv/APG68/rs/hf4Ti8XeNLe0u0ZtPgUz3RHA2L2J7ZJH60AbmpfFL4raOls2o6pNbLdR+bAZLG3HmJ/eHydK6Hw78UvGV98OPGmrXOs777TfsP2SX7LCPL8yYq/ATByBjkHHauB+J/iCLxB42umtCP7PslW0s1VcBY07Y+pP6Ve8Jf8kh+Iv/cN/wDShqAD/hdvxD/6GH/ySt//AI3U9p8YfiZf3SW1prUk87nCxx2EDE/h5dedxRPNKkUSF5HYKqqMkk8ACvVdVuT8JNEtdK0zYviu+hEt/eY3NbRnpGmeh9TQBJrHjz4zaBardatPeWcDHaJJdOgC59M+XWp8Lfil4y8R/EfSdJ1bWftFjP53mRfZYU3bYXYcqgI5APBrA8J6xqt/4P8AGF74j1K8u9KNj5SfaZWkzcs6+XtDZGRz06VnfBL/AJK9oX/bx/6TyUAH/C7fiH/0MP8A5JW//wAbpy/Gr4iu4VNfLMeABZW5J/8AIdefV6v4Z0ibwd4CTxfa2a6lrGpEw2IWHzVswOsh/wBv0oAz7z4vfE3T5hDea08Mu0NsextwcHp/yzqv/wALt+If/Qw/+SVv/wDG6pQ/Dzxx4gdtTbSLyU3TmVribguScliT/M1l+JfD1v4bmSzbVobvUEOLmGBCUhPpvzhj+FAHqPxS+KPjLw58RtV0rStZ+z2MHk+XF9lhfbuhRjyyEnkk8muRT40/EaVwsevM7Hoq2MBP/oum/Gz/AJK9rv8A27/+iI63fglpcNnLq/jC+2Ja6XbsI5HzgSEeg69vzoAwz8a/iIrFW8QEEcEGyt//AI3Sf8Lt+If/AEMP/klb/wDxuuVuZNT8WeILi5jtfOvryRpWito+Mk9gO1R6xoOq+H7hLfVrCezmkTeqTLglfWgD17xF8UvGVj8OfBerW2s+Xe6l9u+1y/ZYT5nlzBU4KYGAccAZ71x//C7fiH/0MP8A5JW//wAbo8W/8kh+HX/cS/8AR61h+BdK07V/FtpBq9zHb6fHmadpDgMqjJX8aAPSLPxL8ar/AMJyeJbfUt2noCw/0S38x1HVlXy8ke/tXJ/8Lt+If/Qw/wDklb//ABut/SviXqWq/FnTBZKY9FWQ2UGnxABPIbg8Y6nAb8K4b4g6MNA8eavp6oFjSctGAc/K3I/nQB6V4d+KXjK++HHjTVrnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/8AC7fiH/0MP/klb/8Axujwl/ySH4i/9wz/ANKGrz+gD0D/AIXb8Q/+hh/8krf/AON0f8Lt+If/AEMP/klb/wDxusfwVB4emv5v7e0/Ur/Cf6PbWRwJX/useoHvmvRfEmj/AA60jwlfSX2iwadrrRk2llbalJPIhPCl8nbweSMUAP8Ahb8UvGXiP4j6VpOraz9osZ/O8yL7LCm7bC7DlUBHIB4Ncf8A8Lt+If8A0MP/AJJW/wD8bo+CX/JXtC/7eP8A0nkrz+gD0OP40fEeVtseus7ei2MBP/oulT4y/EqTOzW5HxwdthAf/adei/s+eHksdDv/ABNexQos5McEzkZCKcN16cjFd74s1SL4f+Hp9V0XwzDdIzNJO1vtjEeed7YHIz6UAfPR+NnxEUkHxAQR1Bsrf/43XX/FH4o+MvDnxG1bStK1n7PZQeT5UX2WF9u6FGPLISeSTye9eQa7rE2v63d6pPDDDJcyFzHCu1F9gK7D42f8le13/t3/APREdACf8Ls+If8A0MP/AJJW/wD8brWvviR8V9O0Sw1i41giwvwxgmFnbkEgkEH93weK8qr2P4iQyaf8FfA9p5qOrhpGPGeQGAH03YNAHOf8Ls+If/Qw/wDklb//ABuuv8RfF
HxlYfDnwXqttrGy91L7d9rl+zQnzPLmCpwUwMAkcAZ714hXf+Lf+SQ/Dr/uJf8Ao9aAD/hdvxD/AOhh/wDJK3/+N0
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4Aez9d6B323bX9YvlZ++9YC9gwQ6CCM/JtQA2VMAQjAgiKAEliQZjiLn3XIyISsQYMJYYRLBgFLABQpJzjlFRsffee++Kwu/1PO99Pnfe+d3PvgcIv39+Z/0x91hjjvEZY/Yx51rftb/Dr/yVv/LX+fj6uAY+roGPa+DjGvi4Bn7t1MCv+2sH9mPUj2vg4xr4uAY+roGPa+BNDdjNtKH5Dt/hO2AsHXHWE+av++t+1so0sfHjnFDoWYn/rEyGlvXsbczZOmUoZqVctxG/3q/36xFz67rUZWESc0VPJhVp6mfuWRaKYSY/9ZM/gY9OwJmV3+A3+A0uxazMZ4QrB07JmCdn9PycFRzy8ZdObFA4v/6v/+uPP8CXidPKC5KPsDiuWX9BV1ZWyKNLk6fuGgfRJRc/mUnu9kOppz4z4ZUlSWIXMYSMPuLUl4jNjXEw0zrLMsCILE434ipIMjMx4TgDvMoy/q8hMeuZe9bK5fB1+6wDlSLwRuXEMC8rAE+Z8EOgNXOIXXLRJ36crEBzzQrJ+Aj8kCPKCmpisuLMZ0S2IgiMMysJT6zbX5M0HyrXZWUOVJbPaeUjevVkpT+rDuirjhGZfMRN6+THKT0dzUqSSy/ilB9NpvKPM0LWrphnWWRdbuAQa5aMTss8vttpxTlTbrh1PVXch37gfEi+/vtG5LM4y31bQSYwguenFZjPlmXyEcmcDuTbs3bLOtvlBLmQ3QI/kRMIZMJunxU7yzLhz0lc4OS13QtaszLFs0IwXXGAXLfBYiImc9GpZOXZYgYiXYWnEuxy30YkDJn6yjLhE2QeZqisMaeCOLWuW1mPVk7dby/6WSusd51e4WT0Iq7bOTY+zqxgVodnbrT07EJuXYRdYUZI8eOo4SoZEzEraJyYJBHSN3ifwSSTuawkECz5TLyG+JAOx+2soJO/0vhvy004/Etxt3JnZcyLeES4ON2OGbEa7vbJSn8u0cves+U5QQm4fVYsqMciXcJzgHz0yaml41BM122ceZuVhCcToMLjlIXzGuJNE8ZHy5JegCTjZ+KNwc9qnkf5JD9KeunmzxTXLvEnnA8Te5YgU0mndYkN5LFdTpXEJhyIWy5dXg3/FB6dlVRIjo8Y5uhBIV6wMrTJn2XJBJmubqVupascdFYi4gdIcsjJSAmsXah0hZkMzrRCODHRj0wqp0zqZ1nmyQmeWA7M4kk8Cl+5ev5p5cz9dqQ5/2jlKm+ung5X5BWcP9GnJPoUyArOcEZgdsHBXCucxUwgFTTCFUGMilvEyvIm/2kgjE7mxMdZYdEAgZy25kO65WYlTir4w5nKIzEQWdEnJ+b4K8uEB0jGNdPjP0o+m0U3PuLJymlsOqeB6Sz3bQTJrgROkGetECY5sdUjoqzf7rf77U5bMXFGnOqYWRngqTv61E39lEcT6ErlzMV5dnCe8jP0SBCL+TnlGa0sqZzpYM/qGnP4F6fbynK68dgupy30eXtizvrJfBv9aOVtkr+q/NxL621WzlKvNU/FGX22sKckOisnzilQ18UhID0JVuLM3EkMMLGVJRVp/FS6PdV/+9/+t//yL//yn/gTf+Jf8pf8JehT+ESeSsxZGf/XBjErpyfPlut0+7GM+fYsH3PtgjZIpcyVUry6Kz7O7/K7/C5f9EVf9Pv9fr9fyEWcssrFzGG30VRWllQyUe4bvacBjpPWCLqEu9CZkOIkicmBsi4r2crK0iEv91li+I/yp5V5+zIIsa7ETpqh6eKPzsrzpxC/4lf8ilNu9G/4G/6Gf+Af+AdqmN/qt/qt0CD+p//pf/p3/p1/59/+t//t/+F/+B/cThIRyOnKcjEznxgXCf+//+//m0DEd//u3/0H/sAf+Mt+2S/7h//hf5iV6SJOQ7MyJk5lzgR6tuhmGqFFGTqzMENDdK3u0srPDzNf/x1n1s/ci5
7MiEtgt6cbz5rAXHUl/Jv/5r/5H/PH/DHSf+wf+8f+j//j/whqhe22mhm43FlEdMu3LFY/8R8dnvXf/Xf/3V+9evWb/Ca/yb/0L/1L/9Q/9U8FOJ8vB05zo5M507J+x9/xd/xdf9ff9Tt/5+/82/w2v83/8//8Pwqlm/0r/8q/8r/8L//LZWVQiOEgrvKeXiVG3lzzXb/rd4X87/67/+7//X//3zirgclgqg18F9oFefUTEzi+akFQXP2UawZRhGi5JIklM69e4765ZmUOpPJh/lP/n+Lv8Dv8Dj/6R//ov/Qv/Ut/09/0N/1v/pv/htjf9rf9bf/b//a/Jb/mnjriWeYp8GuDXrmArypWxixOZsTlycmf7pgITFWdiapISylvdSh11Ux6l1X5e37P7/k3/81/8y/9pb/0//w//09aVNIlhpDirDVxuvDZWjWOiD9dwrk0BFmEL8xu61HSD43cf4OaLbckfqff6XdShO/4Hb9jHQzUf/ff/XemzX/z3/w35WYrydJA83AGZFGUuiYwGif65IQwTobyLeYGCPV7mbmU3Q7IUPle3+t7/eV/+V/+Xb7Ld9k+w8j8Jb/kl/w9f8/f87N+1s9aVc77txHDJKB4q7jkf6Pf6Df64/64P+7H//gf/4lPfOI//8//c2uMlUYWrdMftJKsS122TsyZqypLkx/m7/P7/D6/5+/5e1o75Srpf/Vf/Vf/yX/yn/yn/+l/SmxGhzPmWa1zAEiBifFvrvz/vLnkEv7f//f//b//7//7//K//C+tygOZ4iOHRdccIFB1xXT7Jv91z7DACM1Unbl+ywz+s9cJOAFQ0RFna5KXNa3T+h/yh/whf+1f+9cqqWnun/ln/pm0VvnDPNUrwqNdnOQZEsp84Rd+oYXz8z7v88gnrGhmhJ/5M39mYrMyx4gR5gOO3NJ0qbh0GPxoUf8f9of9YX/in/gn/ma/2W9m6fqP/qP/qGWGfMjEILgQlQsdWlak8yFinN/6t/6tf4/f4/dg4r/+r//rf+1f+9fqpblHck7mCX62lmZlt4hkpPmzW1Df5/t8ny/7si/7jX/j39hgsWp+yZd8iYnGeEzrVEG34CH+f3zxk+eqcbNPdbVb/hg4v9vv9ruZMc34muO/+C/+C71a17J8Kp16kxqVxqZxRH4FPMsyZsgZJYDIYi0lFb4IkvQ0tv6tf+vfMiqJ1TpAXFRw6k4IHKkrnFIy45f7yJlAgNJ8GE6YqZeFPnFSHE6SUs7/sDeXxWbM/+v/+r+Mkb/hb/gbdLxMnGjzYfIRl1225oCs0dM6nSl3CBltyDzJk04h80N5JIwcUz9h6+S/+q/+q8rwb/wb/0abmJ/3837eFp5HxZx42cqsi4u/4Au+4N/79/498ppf+g/9Q//QH/lH/pHPwipPfISLsAsntJhSty7d7hSeGEJv/lv+lr/ll//yXx6CVFT41V/91SYguRBSTCWZcUaY4v/oP/qPtjT+gB/wA37Ej/gRpl2rr/lrmMH+gl/wC770S7/0O32n7zTFiNMKzqzkfKVYWS5dt+aXf/Qf/Uf/1r/1b70aIvVHeRxZs+J2JsqaCrHRj4TGMqmZTAUfj7lZP608K3MyqajGn//zf75ea73UEPqbBUyA5vbf//f//T/gD/gDkl+N5WFWYuJo7vgneLmmWuPzr/gr/grzMi3z2mSqBAJJnikZmCsL8K5nDb169erbvu3bCPNczXjNJCgq2XpWS1Zis5LwqQuh27L0tw8++IC8JvgL/oK/QP2g33///T/+j//jh/bk6Iemx7+sBPjtnmZlBQ//uo2pov7uv/vvJm/U/DV/zV9j6vyz/qw/62/8G/9GxcFRQHHAN33TNznk+C1+i9/i8nNlUTnAq96sSFdj+CkarT/oB/2ghjwTv/fv/Xvjp5swrS4qOF2zEkhZ9RbCY86EUOOP+qP+KJ3NYmlmy0S50qkggirrtJLdqZyEI6W/6q/6qxI2WJSlIWNOtum3UcvErJ
y66BThXwI4SY44FdX87/w7/87OGH7/3//3F5e7fsvf8rccwqPKU1nOIoE75aaMH22b+ef+uX+uSTmOGe1n/+yfDUG
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-123655.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 26,
2024-04-08 11:37:01 +02:00
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:30:05.578985Z",
"iopub.status.busy": "2023-02-22T16:30:05.578388Z",
"iopub.status.idle": "2023-02-22T16:34:16.926424Z",
"shell.execute_reply": "2023-02-22T16:34:16.925434Z",
"shell.execute_reply.started": "2023-02-22T16:30:05.578945Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [02:56<00:00, 5.68it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIiAiIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxR8ZeHPiNqulaVrH2exg8ny4vs0L7d0KMeWQnqSeTXH/8Ls+If/Qw/wDklb//ABuj42/8le13/t3/APSeOsDwNo9vr3jfSNLu/wDj3uLlVkGcZXqR+lAHS/8AC4/iUIvNOtS+XjO/7BBj8/LqH/hdvxD/AOhh/wDJK3/+N16D4j+JJ/4Ty48EDTrZfDYxZPD5W1vu9R6AHp7CvBbuA2t7NAVI8t2XBGDwaAPafEXxR8ZWHw58F6rbaxsvtS+3fa5fssJ8zy5gqcFMDAOOAM964/8A4Xb8Q/8AoYf/ACSt/wD43R4t/wCSQ/Dr/uJf+lC1wGCTigDv/wDhdvxD/wChh/8AJK3/APjdH/C7fiH/ANDD/wCSVv8A/G6f4K+D/iDxfGt3IBp2mnn7ROvLD/ZXv9a9A/4U58P9Dsnl13xQzH7u8ToiqfoATQBmeHfil4yvvhx401a51jffab9h+yS/ZoR5fmTFX4CYOQMcg47Vx/8Awu34h/8AQw/+SVv/APG6v6bb2Fp8O/ijb6XdNdWMcmnLDOwwXX7Q3NeYUAexr4v+N7orob9kdQysNNgIIPQ/6us+9+Jfxd02PzL29u4Exnc+mwAfn5dcnpXjrxjZSwwadruoZ2iGKISlxjsApyK9p1rxLqfhj4T31l451EXuualE0VtaLtDIrDALFQOmcnPoB3oA5/4W/FLxl4j+I2k6TqusfaLGfzvMi+ywpu2wuw5VARyoPWuP/wCF2/EP/oYf/JK3/wDjdHwS/wCSvaF/28f+iJK4FfvD60Aek2nxf+IlxfwW0/iZLRJSMyz2UCqoPc/u+lWdU+J/xN0iOFrnxLbkzKGRY7e3Y4PfAj4pnx1jeDxnaQsw2x2ESqgUAJheQMdfxrzS3tLi7cpbQSzMBnEaFj+lAHdf8Lt+If8A0MP/AJJW/wD8brsPil8UvGXhz4jarpWlaz9nsYPJ8uL7NC+3dCjHlkJPJJ614k6PE5SRGRxwVYYIrvfjb/yV7Xf+3f8A9J46AD/hdvxD/wChh/8AJK3/APjdH/C7fiH/ANDD/wCSVv8A/G6d8J/h6PHOuO96JF0q0w05XI8w9kB7V6lqvwz+E8usx6TDf/YdRkOxLaC7Z2LH2bdQB5X/AMLt+If/AEMP/klb/wDxu
uw8RfFLxlYfDjwXqttrGy+1L7d9rl+ywnzPLmCpwUwMA44Az3ryzxloUXhnxfqWjQXDTxWsuxJGHJGAecdxnFdH4t/5JD8Ov+4l/wCj1oAP+F2/EP8A6GH/AMkrf/43W7J45+NEOkf2rJNerYbd3nnToNuPX/V9K4r4eQ6PP470tNeZBp5kPmCT7jHBwG9s4r3q0t/iHcfEm0S7a3l8JMJFCWgXyDBtwFPfPI/I4oA8d/4Xb8Q/+hh/8krf/wCN12Hh34peMr74ceNNWudZ332m/Yfskv2WEeX5kxV+AmDkDHIOO1edfEbSdF0Txre2OhXZuLRDlvSNyTuQHuBWr4S/5JD8Rf8AuG/+j2oAP+F2/EP/AKGH/wAkrf8A+N1ueEPiN8TPF/iS20e28SrC02S0zWMBWNQCSx+T2ryOvUPhy8HhzwT4p8V3SN5htzp9mR1MkgPI+nBNAEF/8Y/iDZahcWqeKEnWGRkEqWVvh8HGR8ldP8Lfil4y8R/EfSdK1bWftFjP53mRfZYU3bYXYcqgI5APWvECSSSeSa7/AOCX/JXtC/7eP/RElADo/jR8R5X2R68zt6LYwE/+i6T/AIXb8Q/+hh/8krf/AON1f+ApT/hZCK0W8tay4Yjhfl7/AF6VyPjjw9c+F/GF/pd0Yi6v5imL7u1uV+nB6UAdDH8aPiPM4SLXWdz0VbGAn/0XTD8bPiGP+Zg/8krf/wCN1ufCC90zwzoPiLxXfJC09qggtQ3Ll2HAX0z0zXmeq6TqelTRnU7Ka1e4QTRiVNu9TyCKAPYPij8UvGXhz4j6tpOlaz9nsYPJ8qL7LC+3dCjHlkJPJJ5Ncf8A8Lt+If8A0MP/AJJW/wD8bo+Nv/JXtc/7d/8A0njrgKAPQD8a/iKuN2vkZGRmyt+f/IdJ/wALt+If/Qw/+SVv/wDG6X4oMjweEHhjhSFtCgxsXBLAkMW98g15/QB7f4i+KPjOw+HHgvVrbWdl7qX277XL9lhPmeXMFTgpgYBxwBnvXH/8Lt+If/Qw/wDklb//ABujxb/ySH4df9xP/wBHrXO+FfCOqeL9Ra106NQkY3zzyHbHCvqxoA6L/hdvxD/6GH/ySt//AI3R/wALt+If/Qw/+SVv/wDG60nvPh74EdLe3sF8V6oo/e3UshW2RvRU/iH19K0fH9lpfiP4V6N4xsdKsdHuFnaGaC2hEay57jHXG3v6mgC74d+KPjK++HHjTVbnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/8AC7fiH/0MP/klb/8Axujwl/ySH4i/9w3/ANHtXn9AHoSfGn4jyttj15nb0WxgJ/8ARdI3xr+IqMVbXyGHUGytxj/yHVLw58SdS8K6UlnpemaTHMjlvtr2oeds9ixPStT4mvFf6L4Z1y6sYrLWtRgd7pIYwiyKCNrkepzQB1Hwt+KXjLxH8R9K0nVtZ+0WM/neZF9lhTdthdhyqAjkA8GuP/4Xb8Q/+hh/8krf/wCN0fBL/kr2hf8Abx/6TyV5/QB6APjZ8RCcDxBz/wBeVv8A/G6VvjX8RUYq2vlSOoNlb/8AxuuX8J2P9p+L9Hsd237ReRRZ9MsBXQfF+2S1+J+sRxurKXV/lAAGVBxx6UAT/wDC7fiH/wBDD/5JW/8A8brsPil8UvGXhz4j6rpOk6z9nsYPJ8uL7LC+3dCjHlkJPJJ5NeH16B8bf+Sva7/27/8ApPHQAf8AC7fiH/0MP/klb/8Axurdl8WfilqRcWWqz3HlqXby9PgbAHU/6um6R4E0rQvDlp4p8bTXCWdyQ1nYWw+e4GAfmP8ACDXovw/8X3OqPc3mnWNn4Z8G6WS83lRAvcNj7rM2cnHXGDyOaAPMm+NfxFRira+QR1Bsrcf+066/xF8UvGVj8OPBerW2s7L7Uvt32uX7LCfM8uYKnBTAwDjgDPevMPHGvweJvGOo6ta26wW00g8pFQL8oAAJA7nGT9a3/Fv/ACSH4df9xL/0etAB/wALt
+If/Qw/+SVv/wDG6cvxq+IzkbdeZsntYwH/ANp1heA/C03i/wAWWumoP3C/vrlieFiUjcf1A/GvVNQ+LvhHwzI2me
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4Aez9d8Dv2XrXdT8g9l6xt8eOFRsqyuwcLGg0BhASiGioQRFFY4Akx8yZQYiUAAKBCFZUIhhQxAoczsyoYO9i771jFwvPa+/3PZ9ZZ/3uvc8B4l/PrD/Wfa1rfa6yevl+7+/v2/y6X/fr/j8fh49r4OMa+LgGPq6Bj2vg/50a+Lb/76j9WOvHNfBxDXxcAx/XwMc18KoGnGbOA823/bbf9tt8m28jp/iqJEyAMQNPZMQpixayAiB8u2/37cQpkYW4kjjxw7yUeeXVBZ4Gub/Jb/KbiLMS/1Qil9Hf9Df9TZf160ucprNCJyWMnqpy4+REX7AclpVjCBwBMUOrsTjiiMDR44yQe4UTOdiIlSUHTjDOYPlGc6VGLBdmueHLOmWzMg5i9KO3y3qFekKOmelLqmRWzqxpOMX1hPHz/MzFKZzMSh1nVlICPCRCkgNxxj8xcsMEG35E/s9KSaqmLQ58IumJ+Wyc7CWe1KOVZzV8/sxHK2QvK5l+Vuez4iGXNWJ64jxamYlTZMyzAsf8nMQbrJA9DZ0jZWox12on/nImK9OQOOWn7JhTHhHm9ARfsnBinrUSAPgiSr6OOfwJi36qsf409qblKk9aYAAUXgggxhQQFUOMBjjtYWYlqTMrhcVZSQNk4DEjJrvkCEZPK+MnIjmv0AK++GROeVkTfAV/ijCfKu5Ev6LTebqNHXNW4sRMQfTiZWVlgvGvZBpOnZf1soJdWRTirMbKhRwxT2ZUViGFAMMkKAlQ7jCIs8YSOQUBJIUZQp9KRr9CfZR1wU4rE8mfBNGnCcyS4tPtaPH0R8S/rJyCFUQcM+WnErQgN0BOxkEbR3HEl5XxBx4nAh8xtSXFV9FOKeDLSlInJnpqH7PizNyzgMvKKYJ+g3Jq0zy1SyKq3mVl5dL2Uv6zawY+zuKI8adwqgDQwc6yBJie4dOw5Aj80+fZncURp5UxEZfIOhgTZ1YWcU7Tp55UPVqh8NRziVwOQAYufgTHebLyaOxZAboqzEvdrwJYA0PqEjmLhwY7rcCnKqlTPHriaxWc19XpTJOdlfScms8xfPInfhEX5lR4WuGYQHZ4BM6Ssn6r3+q3+u7f/buvLBna0epEnj7MyphDZmL8iFN/LuGPmOyqMalZWSly/hKcckTWxcLpwzTELFeMn5WXAofI6BGntt/sN/vNvsf3+B5/99/9d//QH/pDf/ff/XdfFgde1+tWloGZFpZEvPTglQ9XPZyYyhj4FAlzWUkhK6k9pdBZXy7Pg01qWRNnHfO0sqwReVJ8MbM4gFyhwmKiJ4V+1spk30zM0JQPf1rBPK0MEzHkxf8NSD5aOZWf9JRfTMmLM+SIRytlJbg6KXnGYKdyyDN3ybS9wcqpBC1U/whKRkuejkU/a+Wx+QaOEL8086HCMU+C5mDFp87K8tEGiiIsiP/7//6/TxVxZNGVzDD/1//1f0GOnxQ9/8//8/+cGk6FuXvpAc56UnIjEpzCYRB8yHpInIhTQ3rKAp749J8iF/0s5mJeyct05n7L3/K3/FP/1D/1L/6L/+L/6r/6r37b3/a3xTTX/Jf/5X/5r/1r/9r/+X/+n9VkjbRKA1jRcp5mtsCqELRwOlwTjEMVjljIRHrE04Av9xSZlfHhMzRzE58gIhhMxKRGBJ6SKznB+HlumfmL/qK/6Ht/7++tAv/xf/wf/y/+i/9iJUqPOE5SZzy7CgJzqs3WWSflWvU1B5rseh1ZNE4VMrWzlXI6hTEjcBKcOL5m1bi0pTBMMVWIshApyeKSETGLwZabSJppm4kwlyrJ8IunZ5qXFXiA+OlHV12ICQ45YqoiHpHjXMgruXJd/DM5VbM+zglDA5xZ4XF0jwYgOpGpKln/ObUFWJ2cWaOnn9qQWV9yyIjlJljZxULdMn70qSR+StDwzAUQ459j//TkdOAqCyWpPe
M8JJX+ZV2yAE/LzAwHzRs07bOdJcx1rLIey3kWBl4IObemH79cHPqpEs/p4ROfHsTVD5INJp5gRJ6X+yf/yX/y7/f7/X7mlF/za34Nc7/D7/A7/M//8/9sUvuv/+v/Oh+sB//L//K/SMKfFuWenJmAyef0D5OsrP/j//g//sl/8p/8/X//3//P/DP/zP/9f//ff/Wv/tX/2X/2n+V/bkMGboGRla1UwUhyVcBZcn1drhqTzAEAITBORCayEmwAycpCih4BBxIzvhjzd/1df9c/4A/4A/7AP/AP/J1+p98J8t/9d//df/Ff/Bf/8//8Pw+czvSnJz76NBpMnOYRM4TDlrL8IX/IH/In/Al/Atl//V//12uLHJtgSPHMoa9knmNaRVZeOtHVXu6B1RyQj96OM+KlmVcBJ+ZZTPTqHCrAk8CrlWZ0BE8UTYCsdGvKCykXZ7DlJrUkgjZxfHHJOLmHk7akMIeZw1M7TuArDrYamNQFO5MpPJFvNjHZORmHBsQpexZkUgHKugCT/aP+qD/KVKBL/0//0//UAJRVblYub8/kDCFOfuLFw7xU+uF8GLHk63wjK2tgRFbqOZICzLan5f42v81v8wVf8AVuAv62v+1vawsFE3JjYRz8ed7QkBVHLJwiknKFtJ20rDFfin2IfFpm5Am4NWR0ikbLEmD+6D/6j7YxV3KxOfrf+Xf+nf/0P/1Pg4lJpYdnkmD5EX8Ahdl8ipnm04FJISq53BQCLxcHny2YkxmdufDRqv4H/sAf+Nv/9r+9KZLU7/17/97/9r/9b2uMf/lf/pdtnKn6j/6j/0iWafTf/Df/zX/un/vnfu2v/bUJitfYMxSRfrKSv9vv9rupHwvJf/gf/ocq53/9X//XD16FP+VP+VN+r9/r98L59Kc/baWBhCeYzgrYZBc9E2dhMRNB1PCU5F7JAJcIE2ADoIWkshJN6sxCZytt3+E7fIe33377j/vj/rjf7rf77Uhp9L/r7/q7/o6/4+/49/69fy8lrGRCjE6qOACFiGyJ82Gml4X4rX/r3/qtt976fX/f3/c/+A/+g1/wC36Bo0waxClMNuZUlRxAUta3//bf/o/8I/9Iw+yf/qf/aWv8BHMVOI44cXPNH//H//Ec+MxnPmPGiV+cq8Eq4DjpwcSh+eyuaO1uH2OM/G//2/+W+GKawxMU8HEqAtrEsQlCMkCY0ZJMbJkkvqyI4vjoCn5qO73FL0zPCPzf/Df/zX/n3/l3VqW/4+/4O/6P/+P/aODYq9EJM58z96Gaj/6uS2DNGfT4MSUx6TcwDcP/5r/5b9ZG0zWXZispgLkROOXFZV0A+w8LzO/yu/wuX/ZlX6a//X1/39/3j/wj/4hB+p/8J/9JGh5tPVoJKQ48kfiqS7urdlVnwlQ0vUt7qUAdTKyYFeSxpDTI2vohWVmYQMOvBkqKBUx7we/5Pb/nX/aX/WU6v5tnk1hZxauuk0lKT6Dz7CEzdCEldQZb59/j9/g99AFJQ/W//+//+/ypBlgp+VIWJTAgL/OIgtwRL6GvCuk08Kt+1a9ScRWSjb/ur/vrVFwAMZFotYNOZ1bSgBNzSATwNJzEicHPzwGWW9asDPBI/Hl/3p+nJxnz6kWsmTW5A4c5KPHFf+/f+/fqIo8anrVSicwmn/rUp/Sef/gf/odV1KPsOHkujsCPpidiVgIMdmpQG0s+EquryY7InHhWJl5DnEiFMjbeeeedwLrsf/vf/rcqzQr6yU9+0siZtlOqCinr0UpImBM2HyzVVhcmvvEbv/F8MDPAI0HPoxVl+fIv/3J8V5R/6B/6h+bMLKofNE/GsQGE/+/+u/9Of/jD/rA/bFbOcs3KmImnaiIjrDE/5+f8nL/n7/l7rNA5IF7D8ZBgnswNgFlBTz/6DHPgZFJo0vw9f8/f05H9//sqmAhOADpB8Wkl5unDpGTZW5iOPS37GT/jZ/xT/9Q/5bhvW+axmZZah0nDpEY8WlkWglSCM20y8SzTwe
KH//AfbtU/wY/0pLIybY/IrBTLtadUUX/un/vnMqSnJW6G/WW/7Jf96B/9o5XrwqdwZSk369MJM1qb0v9n/Bl/xrv
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-123956.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=256,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=16,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 27,
2024-04-08 11:37:01 +02:00
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:34:16.929146Z",
"iopub.status.busy": "2023-02-22T16:34:16.928274Z",
"iopub.status.idle": "2023-02-22T16:42:15.960070Z",
"shell.execute_reply": "2023-02-22T16:42:15.959235Z",
"shell.execute_reply.started": "2023-02-22T16:34:16.929115Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [05:40<00:00, 2.93it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIiBEIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxR8Z+HPiPquk6TrH2exg8ny4vssL7d0KMeWQk8knrXH/8Lt+If/Qw/wDklb//ABuj42/8le13/t3/APSeOuCiieeZIo1LO7BVUdST0oA9Vj+JHxUk8Jy+JF1+P+z4roWrE2tuH3kA8Dy+nPWsf/hdvxD/AOhh/wDJK3/+N1e+ITL4c8HaB4GSbzr20DXd/sxiOSTkRnHcZ/WvMaAPcPEXxS8Z2Pw48F6tbaxsvtS+3fa5fssJ8zy5gqcFMDAOOAM964//AIXb8Q/+hh/8krf/AON0eLf+SQ/Dr/uJf+j1rgMZPFAHf/8AC7fiH/0MP/klb/8AxunL8afiM2duvMdoycWMHA/791P4e8Had4d06HxJ43hkMMuDp+kocTXje46hen1zXeeMPFn/AAiXguJ30XSrLXNZi2xWkVog+ywc/fOCWY8dePbigDK8O/FLxlf/AA48aarc6z5l9pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/8Lt+If/Qw/wDklb//ABul8KMX+EnxHY4yTppOBgf8fDdq4zRtLn1zWbPS7YqJ7qVYkLdAScc0Adl/wu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N10+st8OvhuDpK6QPEWtIMXMs8hEcT+mBx17Dn3rnde+Jun6ro81lZ+CtCsZplKG5S2Uug9V44PvQB1fwt+KXjLxH8R9K0nVtZ+0WM/neZF9lhTdthdhyqAjkA9a4//AIXb8Q/+hh/8krf/AON0fBL/AJK9oX/bx/6IkrgKAO//AOF2/EP/AKGH/wAkrf8A+N1v6X42+NWtW7XGnSX1zEpwXTTYMZ/791i+Bta8HeHPD93qGoaaNX8Qs+y2s54t0SjPB9D617V4T17xKujXvi/xvcNp2nwIzQafFCFGzHU5+YnsBmgDxOX40fEeCVopddaORThlaxgBH4eXXW/FH4o+MvDnxG1bStJ1n7PYweT5cX2WF9u6FGPLISeWJ5NeYeNvEx8X+Lb3Wvs626zsAkY7KBgZ98Dmt/42/wDJXtd/7d//AEnjoAuWHxY+KWqef9h1aW48iMyy+XYW52KOpP7vpVP/AIXb8Q/+hh/8krf/AON11PwjsRovgPxV4tuJFRTA1tCrcZIHJz3yWx+FeLgEkADJP
pQB3/8Awu34h/8AQw/+SVv/APG67DxF8UfGVh8OPBeq22sbL7Uvt32uX7LCfM8uYKnBTAwDjgDPevEXR43KOjIw6hhg133i3/kkPw6/7iX/AKPWgA/4Xb8Q/wDoYf8AySt//jdH/C7fiH/0MP8A5JW//wAbqx4HsPhpLoklz4pv7xdQVuYAcLjIxtxySfrXZ2niT4RaheR+HLTwy32e9kWL7UYyHVjwCGyWHOOhoA4T/hdvxD/6GH/ySt//AI3XYeHfij4yvvhx401W51nffab9h+yS/ZYR5fmTFX4CYOQAOQcdq4H4meC08DeLX02CfzbaWMTwZOWVCSMN75Bq94S/5JD8Rf8AuG/+j2oAP+F2/EP/AKGH/wAkrf8A+N1f0n4p/FbXbwWel6nPd3BBIjhsICcDn/nnXmNeu6b8U9L8LeBbXSvCOkmLXLiMLdXUiZw54yO7H07D0oA0vEfiD41eFdEh1bVdWjjtpGCnbbW7NGT0DDy+KPhd8UfGXiP4j6TpWq6z9osZ/O8yL7LCm7bC7DlUBHIB61qfE/UdR0r4KaTpOt3X2nV790MpcgsAPnz+GAPxrzv4Jf8AJXtC/wC3j/0RJQAf8Lt+If8A0MP/AJJW/wD8bpR8bPiIxAHiAknoBZW/P/kOvPq9R+HXhnTdM0o+PPEpJsbWTFjZj711MOgA7jNABqnxO+LOix20mp6jPaLcpvhMthbrvX1H7us7/hdvxD/6GH/ySt//AI3Xb3fgqLxLqVpqvxH8RvZarqzqlnp8AUNGp4VSMccYH8ya8e8SaDdeGfEN5pF4P3ttIVz/AHh2b8Rg0AetfFL4peMvDnxG1bStJ1j7PYweT5cX2WF9u6FGPLISeSTyax9L+Ifxe1vTbq/03UJ7q2tSFmeOwtztJ/7Z1j/G3/kr2u/9u/8A6Tx1X8AfE7VvAUjQ2yR3GnSy+ZNbOMFjgDKt2OAKALT/ABq+I0cjRvr5V1JDKbK3BBH/AGzpv/C7fiH/ANDD/wCSVv8A/G66D4taZpWteG9G8faTaC0OpkrdRAjG/wBen3shgT3xXkFAHt/iL4peMrH4ceC9WttZ2X2pfbvtcv2WE+Z5cwVOCmBgHHAGe9cf/wALt+If/Qw/+SVv/wDG6PFv/JIfh1/3Ev8A0etcAAScAZNAHf8A/C7fiH/0MP8A5JW//wAbq4/xX+Kcemxak2qTiylZlSf+z4NhK4yM+X2yPzrpfB/wk8P2HhxNd8e3YtkuQPIhaby1RSMgkjkt7V2mr/2Bc/AzW9M8JXa3djZQmIMWzyHDnJ78GgDivDvxS8ZX3w48aatc6zvvtN+w/ZJfssI8vzJir8BMHIGOQcdq4/8A4Xb8Q/8AoYf/ACSt/wD43R4S/wCSQ/EX/uGf+lDVwcMUlxNHDEpaSRgiqO5JwBQB3n/C7fiH/wBDD/5JW/8A8bpy/Gj4jsjOuusVX7zCxgIH1/d13l/4Wbw9f2+keGvh3aatJFZpLdXt/vkDuw6D5gAevH8q4Xxv4o8QWtpL4cuPDlp4bt5isk1vawGMzAdNzZORnmgDrPhd8UvGXiP4j6TpWraz9osZ/O8yL7LCm7bC7DlUBHKg8GuP/wCF2/EP/oYf/JK3/wDjdHwS/wCSvaF/28f+k8lef0Aehr8Z/iQ67l112X1FjBj/ANF0SfGf4kQkCXXXQkZG6xgGf/IdT/D/AOIOv6ZYReG9I8P6dqkjyEx+fAWYZOSCQRkZ9a7347z2y+CtHg1ZLRfELlWxbx4CgD5sE5IXPbNAHnH/AAu34h/9DD/5JW//AMbrsPil8UvGXhz4j6rpOk6z9nsYPJ8uL7LC+3dCjHlkJPJJ614fXoHxt/5K9rv/AG7/APoiOgBy/Gr4iuwRdfLMeABZW5J/8h1rS/EL4xwWJvZbi9S2A3GRtNgxj/v3VT4Tpb2tj4m12KyS71jSrNZrJJV3qCSQTtHJIxUOl/GDxrF4jt5r/VHuYjKEmtZY1CMpPIwAMUAQf
8Lt+If/AEMP/klb/wDxuuw8RfFHxlYfDnwXqttrOy+1L7d9rl+ywnzPLmCpwUwMA44Az3rzv4l6XbaN8RdasrKPy7
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEIAAAIiCAIAAABDhPpOAAEAAElEQVR4Aez9d9iuy5bX9R4UI0FAJIigBBEJLUkJEtbsJkiOkkOTk+QM3fTutQkCrU3OOUrOIKnZa22C5CA5iCAISBBEQDCdz5zfuX67dj3vnHvTnHNdel3r/qPmqFEj1aiqUaPqvp93for/+//+v/8/7z7veuBdD7zrgXc98K4H3vXAux541wPveuBdD7zrgf/3eOCf+3+Pqe9a+q4H3vXAux541wPveuBdD7zrgXc98K4H3vXAux544QFvYzz/3D/3oc8zn+JTfIp//p//55V5DjB47AMu755aCLlaX8U1+aN/ofOlAUMGEHJpeWQfyzRGU7VyXAPGBYBMy5CPZDB7RvYkgGyWRDBL0oIgPKddlOFHMPkXBlejNvyn/JSfEgzvmcfWOjlaN1JrjXHVEV+Yq/ohPTY5bKO36oSEWTUAchj04EtLyImKPpZJqPWkHDynTUveOMclzLwU71lN/npU9SxPSzYukBuXkzj4ZLmc8Ej8KszVl/XxVfTDj/ICzurV39Njk/N6gCteT0DF/Jzq13gsUcjiyjzsMB7AZTD6C6OKEj4ts234YU5gFp7I18OXluf2vRMBANlMwgv0B5peL1Nr3Tk7BZ7HwKcTZval5WRPYwTwNVXOvGjmMcQnGTj2iYqxapRZEtzqqPUsozm1EIsgrhMeEhD+NGBIwKueeSz2y6SrOiGIGZmurKopISMbsL4MA4j4w2E/uR7hKb20nJIfuZ7ETNRaL4xeX1pGeQEoZ8CEhFFd08Wleo3+SfB6xpMSfKmYDZrAVT/MvlySn6xe6k6aS0uqM+Mke8REeUoeXFO+glR9lZZUXKIm5zJg1VcRXFoyIMtPFvBZPe0cfrzYT1j10jKrAGM/kY/4yOp15UWvOi0nAUbVyok9lZ7EwRk/fMCq0/JoQJhRqp7w6IecVcMM+JBaJg1wdif85Jxkj02v0hL7KeRRxSUZ8UkzXsiXWl6lbILGMwzgFKp6qtkkGyPg0rKmU2bwKYqWV1E+iT+1bFNhT8SXzZmd0jVFuRIQfNp5ahn7yAacLCc8gsfezQz0aYl4LJNzUp7wCC7gUQICyGkBJyfKy2k5UxPgkjztJ/tomgxpGUEqlAEjDoglAyb8onmy+jguj2Tsf+zCRZbSUzVjPMgYPC0ITpr68tgjmBMZXHlKSNQoNy6XbY/VsTw2DXPaOSRgfQFfclR76vjJFXwZfErAeNJ/+H2h6+J9jZZTBa60PPZ0GDSD8YJhTuRVrUeQp6In+3LRRP8k8hT1GnjjQkh2Ip5AGM8m5OMAjfI1KtBcfTm5yPScmBM+xQ4/gG2DUa4vcZ1Np5wTfqQhc+wnPK5Ly/CA0U/CWlOkHKApeDQnZh57wfGSK8pHrknIIfw5zABctcIEKKdlZJ9s4FVKCXyNx1L3qh49yoyS5RjXnYSkpaZLoOqFwfLhYJJ8lq/qy6O0k+tJdRGcjOvvq7RcMuvshTwFXk1X9VVaZgb6J1WEr+kkPuVvIF6l5SQenPHjDR9y/bqq0ZxaEEQzsQAYYsNXqvasqi8nmVYYz+SkBY0nmWt6FRDlWsmM8cKPAHD2ZfjRAwipmrTRDKh1LMMDID0yk0sL5EkW/CTyVIrg9M+jhEvLI8GlaF17FeWJzzzlxuVsveDT7LPpyT4ieMS/7MuH06Uxv0orBa9p0vqklonN6Y8SRvBkByB7xnhqwTs8YOMaUvVsfUfS83//hX/hXxhvBpwlgmm58KoXJrEhT5hqyExa6+yJMi0XctVLkWpNgMHJgYcBV4asqml9GT5g0lQ3OkNO4Klr8gGDk/Z6LWiSHPHKS8jwAS+UfKDXkKeWk/eCq54aYa4qaY+YGK9xOYVjObkIOVtPuC4or2k5/NmXIU8gaZWUPgo/MSM+kaS9SsvVi1MveEIGjGAYQM+rtGiNa8YPE37sEz78ham6vsz4+f
akPxfdo+rLhmlMptZpOWXG9SQvss/zeT7PT/kpP+XH//gf/yW+xJeIZtIm/5QGTgvik/6kSULlib/gJOhycmodfPYlyjVdciha0+f4HJ/jZ/yMn/FLfskv+bJf9ssiG/7RmDCXFiyPlFP32DTDUnQR1Ir91DJpj8Cs1XQKvMSe1cgqX6XlFHV5m6JLmmqYuEawalpWfezF/08wV1+mDjCYohP+kHovYtUPpy8nV555VBTNSXnSpOWR91X0Jy8Y2SgfhYz48hj8iAckbSyvAv7tf/vf/hf/xX/xydazL6dYxFUlu8PP7ICzHE1aNNWanHNchh9lwJNlYpUv5H0grzjVren0GAJPukYQkqIzVp9wd44nBvHkAFTXl4TDDADvibhyZmjdgj21PEo4+4KLHE9kGMFTFPLErOkRuBSdWl4j4eJKbPRPNkWwplPLo0kncSyTPAloTvhkmcDXa8E+Cafnn5Q8mYDLLS+1vF7ZyT8R1Mv1f/gP/+E/+kf/6C/2xb7YRfNk9dQy66O8qq9BPikZ+yScWuZ6XCMIeM7w4hQxjwCCK09FKGfSC747lTlZIjg11vr6NwBTMUaYqy9bbDMm4LE8pSXwSasg2ZYWwGhiOcWug4+SkT22RnYSpwVmxANORYNPXhH/u3yX7yIR/FE/6kd9ra/1tRb9kxZl5emxiTqBR6VPYk7tY5+i9UXTtfwi/pf+pX/pIz7iIz7TZ/pM4w04dQ0mwdwg/Bxfra/vy2nhCVM0yZf2VdcRmEuLplo/HDkJTN3JGO/kXFpOfBLOslYynyQLv6b1FIb3rr4kNuKxQJ7wqXrwxMIMxgX2nFogPeccUJ0cwL/6r/6rX+2rfbVf82t+zd//+3//H/2jfyRg/lv/1r91EhB4sdR6ajnpwdGP61/5V/6Vz/AZPsPn/Jyf8/t+3+/7A3/gD/x6X+/rfapP9aliQTP5Gfm8A+/cO6JJS5hL8pDDm9VW4n/0H/1HYv7/8r/8L3i/xbf4FimqnEljCZ+Ws3Vc0zKrAOdCiDLjn5Qw5z967KJXvTCEUzfk4GFmZ4B1mpYpHUEsYwSccGRDromcOo4AvD0iLWt6ZJzeVwGk4ZqEiywD5rGqj526uJ6sxlvTCY94WoZBNsoLfjQ4zOhfpehRy6so4S9pM+xJAPGsepWWR12vUWFf+EE/6Af92l/7a//1f/1fvzTGlZYNB+2TFvIsa4oGPFMvyaqaztbHvtQ6XZeEE3/CyKqe5XinZd2pKV2PZa2JOu2B8cAMOQB+Wi5jZsaj2DUldtUT0KRaCTi1nGTgF3Z94KL8aj2rpHnQK4evL8l51LLW0Z+8QwY8El8EVU8tkzYAzQWrTvI1lBfxWU3LJSoD/tlL9mRSWl5+I0SZemWmqLL4//w//0/U/9f/9X+lOEAT/Df+xt/43/l3/p0/+2f/7B/8g39wZk3CMI8AdshRkk9aZNOVMZABsUQzxqqjrDqasxVNkieNEH1JrG4C6lqlVuwZFo1qAGSSK2dw1Ys4LZr+j//j/4hAzvEf/of/4ef+3J87Of/r//q//vJf/suVWlOahMmBR3laGyY7xcQv9aW+1P/2v/1vv/f3/t6/+3f/bhJOOYhPUao9kMjWRNpZTQLKyGqt+pL/xT9jD1k1Oae0kyX4dNqlKyFKXfvqX/2rSwT/g//gP/gqX+WrfLbP9tnwfskv+SX/9t/+27/7d/9uBB4YY8c8T/5MfvN2wzqTaj1LjP/mv/lvfuRHfqRx+RN/4k/84T/8h//O3/k7CC77VcMoY6/avA2e2C/0hb7Qd/pO3+kf/IN/8Bt+w2+wOv7W3/pbslitdI03OEawvhClOsykBeSxWuc9VwnmVX4YPZpp0fT5Pt/n+9Jf+ku7O2cPfBIqxxIwsVXnXoZJsJrAMwMNAtXIKifwqg4PoNqjy8pHsjBz1MkIHj7eVTVN1MQO0Np8QNNsSeyn+3Sf7t/4N/6Nz/7ZP/s/+Sf/hCf/3t/7e3/pL/0lsy
uxY6/LeD3geNcK2cABIIcP+PSf/tN7ZfE1v+bXjOvf//f//c/4GT/jX/krf2VCzi6EvMq0j/65ES+mfSX8Z/2sn/X
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1090x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-124257.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=512,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=32,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 28,
2024-04-08 11:37:01 +02:00
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-22T16:16:33.646107Z",
"iopub.status.busy": "2023-02-22T16:16:33.645729Z",
"iopub.status.idle": "2023-02-22T16:16:59.754676Z",
"shell.execute_reply": "2023-02-22T16:16:59.753710Z",
"shell.execute_reply.started": "2023-02-22T16:16:33.646076Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [00:19<00:00, 52.08it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCACKAIoDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KPxS8ZeHPiPq2k6VrP2exg8ny4vssL7d0KMeWQk8knk1x//AAu34h/9DD/5JW//AMbqx8X7K41H42atZ2sZknne2SNB1JMEddCngbwR8O7OK48c3Z1LVpF3Lpts/wAqfXBBPPcnHtQBy/8Awuv4ibd3/CQHHr9it/8A43Sf8Lt+If8A0MP/AJJW/wD8brorv4x+H5IjYxeAdLOnZwUZVDFR0OQOD71mfEnwbo9p4f0rxfoEU1pZaoMtZSciJiM8H09qAOh8RfFLxlY/DjwXq1trOy+1L7d9rl+ywnzPLmCpwUwMA44Az3rj/wDhdvxD/wChh/8AJK3/APjdHi3/AJJD8Ov+4n/6PWuAoA9asviD8WtQ8MX3iG21kNp1i6pO/wBktwQT6Dy+cZGfqKxP+F2/EP8A6GH/AMkrf/43XSeB5A/wC8ZwLwySKxJ5HOz/AArxygD2/wAO/FHxlffDjxpq1zrO++037D9kl+ywjy/MmKvwEwcgY5Bx2rj/APhdvxD/AOhh/wDJK3/+N0eEv+SQ/EX/ALhv/o9q4FEaR1RFLMxwABkk0Ad9/wALt+If/Qw/+SVv/wDG6P8AhdvxD/6GH/ySt/8A43Xb2Wj/AAz8EaRZ6f4ttTf626Ca6MAd/IzjAbDAADNch8WPBukeH7rTtW8Osp0bU498IEhfB74zzjmgDqPhb8UvGXiP4jaVpWrax9osZ/O8yL7LCm7bC7DlUBHIB61x/wDwu34h/wDQw/8Aklb/APxuj4Jf8le0L/t4/wDRElef0Aegf8Lt+If/AEMP/klb/wDxupv+Fw/E3yvM/tiby8Z3f2fBjH18usn4X6NYa98QdNsdSUPbEs5jPSQqCQp9iRivWdC+LN9qXj9vDV9okNtocjvaC0W3LspwVAOOME9eMc0Aebf8Lt+If/Qw/wDklb//ABuvr+viz4k6PYaF4/1XTtMCraRS/IikkJkZK8+nSvtOgD5H+MN5c6f8adYu7Od4LiJrdo5IzhlP2ePkGvPbi4mu7iS4uJnmmkO55JGLMx9ST1r1H4i2+mXX7Qt5BrM4g055rZbiUnG1PIjyc1w3ii10q08VXVvpd+t7pqyKEnjTaCuBnH09e9AG78MvBB8W679ovQY9Fsf3t5MeBgc7c+pqz8VPH
8Pi/U4LHSohBomn5S3jUBQ57tgdBjGBW74x8faPovg628HeBpf9EeMNeXeDvYsMld3c84PpjA6V5BQB3/i3/kkPw6/7iX/o9a4Cu/8AFv8AySH4df8AcT/9KFrgKAPbPhvbwD4H+MZ5EZ/MYqyr/sgEfq1eJ13nhz4kNoHw71fwqlgWe+ZnS6VwCm4KCCCDkYX9a4OgDv8Awl/ySH4i/wDcN/8AShq53wVd2Nh410i81KQR2dvcpLKxUnhTnoPpXReEv+SQ/EX/ALhv/o9q8/oA9z+Gvinwrdaz4h0zWzJd3WvXmyOWSHImjJOFJH3ecfjirnx6/sPRvCWheGbKIxz28nmQIOfLiCkEEnnkkfl7Vz3wUn8L6NBrPiPX5YluLHYtuJCCQGzkqvdsgDPvXC+OPFt14z8T3Oq3BIjJ2QRnH7uMdBwOaANz4Jf8le0L/t4/9J5K4AAsQAMk8AV3/wAEv+SvaF/28f8ApPJXDWk4tr2CdkDiORXKnvg5xQB7x4C8CaL4D1DRtV8W6gI9avnVbKx6eUzHALfTPPYVT8dfGTWNG8Q6vpOlaPY6fcxTtC16Y90zgHG705685rb1TU/hb411fTvE+qa/LbXUCoTZO+FOOdrDaeM+hFeOfEfXrPxJ471LVNPJNrI4WNiu3cAMZx+FAHM3FxNeXMtzcSvLPKxd5HOSzHkkmvvqvgCvv+gD5w+KXiLwbY/EfVbbVvAn9qXyeT5l5/a80HmZhQj5FGBgED3xnvXH/wDCW/Dz/omH/lfuP8KPjb/yV7Xf+3f/ANJ468/oA9A/4S34ef8ARMP/ACv3H+FH/CW/Dz/omH/lfuP8KvfBvwamta82uaoipo2mfvZHmX5HYDpzxgdTXG+MbrS77xhql1osXl6dLOzQrjHHcgdgTkgdgRQB6p4i8ReDYfhz4LubnwJ9osZ/t32Sz/teZPsu2YB/nAy+4889OgrkP+Et+Hn/AETD/wAr9x/hSeLf+SQ/Dr/uJf8Ao9as/C74f2Hiq31fVNcklg0mxhI81GA/eYz+gHT3FAFb/hLfh5/0TD/yv3H+FL/wlvw8/wCiYf8AlfuP8K4O5WJbmVYHLwhyI2YYJXPBP4V682mafo37OH2y5s7f7bqd2PJmdAz9eMHqOEb/ACaALvh3xF4Nm+HHjS5tvAn2exg+w/a7P+15n+1bpiE+cjKbTzx16GuP/wCEt+Hn/RMP/K/cf4UeEv8AkkPxF/7hv/pQ1cfo2i6h4g1SHTtMtnuLmU4CqOB7k9h70Adh/wAJb8PP+iYf+V+4/wDiaP8AhLfh5/0TD/yv3H/xNaw/Z78akZ8zSxwDj7Q35fcrmPFXwz8TeDrRLzVLSP7KxAM0Mm5VJ6A9DQB6D8LfEXg2++I+k22k+BP7Lvn87y7z+15p/LxC5PyMMHIBHtnPauP/AOEt+Hn/AETD/wAr9x/8TR8Ev+SvaF/28f8AoiSuA70Ad/8A8Jb8PP8AomH/AJX7j/Cj/hLfh5/0TD/yv3H+FTfDnwv4Sv8ASr7X/FureRaWUojFopw0pIyD6n0wB+Nen3vg34d6l8PdV1u28PTaXBHbu9vdSlg7YXKsqlznPocUAeVf8Jb8PP8AomH/AJX7j/Cvr+vgGvv6gD5A+Nv/ACV7Xf8At3/9ER1yvhvQLzxPr9ppFioaadwMngKvcn2Arqvjb/yV7Xf+3f8A9ER11PwCn0fS5tb1jUru3hlgiCp5rqpxyTtB5OfagA+KniXT/DGhQfDvwywSC3X/AImEic+YxGducnknk/XHavF60Ne1BdW8Q6lqKLtS6uZJlBGMBmJH86zqAPQPFv8AySH4df8AcS/9HrXqnhz/AIRvWfhzeeA/DOog6p9hFzKdpxM5wWGe+DtU9O3XmvK/Fv8AySH4df8AcS/9HrXH6NrepeHtRXUNJvJLS6UFRImM4PUc8UAd9o/wL8XahcAXcVvZQBwrvLJyR/sgDmt7472UGgaL4R8PWTOttaQzEoSTk
/Jgn8d35153P468Ravq1lcavrV5OsMqk4fZgZGeFwOgrvfj5rVrqt14eWzureeFbVpP3bBmUtt4Ygn0HH1oA5rwl/
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAIoAAACKCAIAAAD65UUzAAA85UlEQVR4Ae3ddbSuW1k2cLFRFFQwsDBABbsDde+DYoAIJgYGJooO7ESO+wjYw8AETMQOVFQETm0/BezC7u7u4Put93r2te41n3fvc/T7xvCM4Xn+mOue97zumB3PfN51m+c85znPdetzSy2B576lOnarX4cS0Hs8t7nNbcSEz/3cW4WFDv95nud5pCYpnH3hVbBJ5RCJFUlRFSKqhFM/qZqIhuqJZlH4ytYcZq2UOVVFT2VFn/d5nzeeBL/XuddDZG8l/gPTkAdNP7oawgmmYVILK4G/Wdkbk1Z7zcyUBIhtzPCnK+gAGiJqBX6CJeWJnj2NE5ENd/ZPpCpbK2dRJ7FikpSKWZiJNu9Vgu9JvhBHreAHX2KvB4CSo3xJFUTDbFamsed7vueLjSWk1IOZsKnViChgwQQ8rQQQLwl6YBJKQkhKtIYKCKe2wILE8UwrlQ0g0QP8tEUnKeHRUpuAKoyVJJVZAn8++KJNDRFO+XsCbMvLkiWZbMEhRKOx3k9dmKLFJGkC0InOLFXVdDqwcp7/+Z//tV7rtR796Ee/3du93Qu+4AuGDzNbPfpE+6XMt3pwponS9bNWJGEu+NoqDFHZWpmp/z26nh8VX6sn6LgbAZy6hTOjU3UwyWotAUzM0gjASBGBqYkQkbrDHe5w4cKF//zP/3zWs571eq/3elUbNyJ4YmD0PEmzESSpgrWyiOxhC6caKrjPSzEhLqdhgV05Gisnc2MecYQSkZPQiP/4j/9gLNEIBKx13/72t0f/7d/+7T//8z8jIIUFoyOF6NNUBEPhF4bA9yD0mLd5m7dBv9EbvdGrvdqr/fRP/3TAXCLoqVSVtwKSFG2hm6OAY2KxnijNVbiATzwbbksVFcYEuhbDkVS7OC/2Yi/28i//8i/90i9tBvm3f/s3xfWv//qvz3jGM9BTT4w2PK2esKYHKfEaA7jtbW/LwMu+7Mu+xEu8hEYt+hM/8RMa+N/8zd+ACav3KAGj4VOLiE9gCA+msOUOYOz693//dwD50RpkBp3ik5qcI2poKVmAcEoEWRHm+B+LvPqXf/mX5BdMUmHVH46kPUdSUhMWUJde6qVe6r3f+73vd7/7nTt3rqmK613e5V1+9Ed/VPuO2r3drXqahyIQoZnkvTLSXd78zd/8gz/4g9/xHd+xNhD/5/D81V/91Vd91VcJkxRHqaWkXqIVwawhMM+JpZFtGoj8/u///uu+7uuyq92lemjzpBCrM+aiM3TCAkrgxwqLFN71rnel/8/+7M/I3u52t/vVX/3VX/u1X/vHf/xHMAAhW2SJJJpw8bNgxOUeA/UnfuInfuRHfuQLvMALpJHRTA+j3/RN3/TQhz70B37gB/7pn/6JeEwgTq1g7bmnyQebb/Zmb/aEJzzhj/7oj9SzFh2RhspL0/u7v/u7G2644d73vne8XDQEnKSEXJxRdJxWWOg73/nOP/ZjP6bjE3zP93zPBTyjoZlDxMo0jU40YSzq+iY2JeX5h3/4h7//+79XOn/+53/+xCc+8Z73vGe9iieNhqAnVpI61RaJWf6LvuiLatCmALYIXnvttd/4jd943XXX/e7v/q66V5gaxPu+7/u+yIu8SMVDxMqZwQ1LJokhAjKCffInf/JbvuVbvviLv7jKx1QZP/VTP6XRqQ8NnOG73OUur//6r/+ar/maavExj3nM533e533bt32bnsvF6JmlSUNMsBJ+iBglgoBB8JstQ5xKClOJUIj2AHjiZMJGYdAxnTAKhYnK1Fu/9VvrlzgJKdSfHvjAB3Lpj//4j3/91389TkZzxKdaHCJNncQiqHDe533eR+nLhdHlK7/yK//gD/5AVb3TO72TepJBA6xyMwApz+ppXrbqaX5Y5TFdyuJTPuVT7n//+7/6q7/6C7/wC5PkNy3q5mlPe5pupMVpevCq7WVe5mXOnz//sIc9zI
hx9dVXk/3Wb/1WA13ULjlJGcFwLn4EMDMGw10cdg0OTGRugPekpCRVEE2bsE9goghhfdDO7nOf+7zJm7xJkQpL3STK/7vf/e6qp/jw43NFEPVhMtEEa1rUMPAWb/EWv/3bv60yvvqrv1qDDv4v//IvAxZq2Sop/IQ1t1aPZGmv/dqv/aAHPUi1K3cc7uqPN95448///M//4R/+oa4zdaF/53d+B1/eTIAc0mTSKmMmBTRFcPal2dqCVPds2fTQo+lZyKmeaGjmWxApqWkFR9UKNQKwAHCYeI/3eA8Tsvr+zd/8za//+q/XR//kT/5Ebb3bu72bCdwq8f3e7/1+67d+i/UaaruZJjhDrbCwuLfwU83GgOuvv751A/PsZz/7a7/2a9/wDd/Q9k6JdXBLq62qrXqSkxTZne50p3d913f9+I//+IDUisHqKU95ikqqGCJO13V19nu/93vxxgrSfP6nf/qnSU02KpvCkquZMTRAkYi//uu/jgiv4lgBAdcB0QIiIown4UNSiFbN973vfV/u5V7uR37kR77ma75GAQX/fd/3ffz/wA/8wFd8xVc09L3yK7+y6qkzUVWdNRGisKN8Q4hye6VXeiWdMou0wPSnL/zCL3z84x9PXLkZh8KPw8kgzmnvKYt/mpi0X/mVX/nlX/7lRz7ykT/+4z8eYSEYja2V8g2AcqXyccyH2RVVZ2ElaJhlmiJIOdJvztTAIy4pqWRxPAA4wDgl8KM8AHSdhEnqO7/zO7/Ga7yGBveIRzxCeMc73tHAoib+4i/+wpSp03/AB3wA04a7qMriPrKMhjnDmKihmoYJbRD7uZ/7OePKu7/7u5sXfvEXf9FAas1mWNJ1TBxg3//9329CmmpLn/YeboVrZ2Ntw/vP/dzPffKTnxzmNIzT8kqqvnmve93rVV7lVcBwfvInf9Jwh1jaQsBVFWLmrWqVhbVN8CpeloxCUb4vpkiVTyG7wNUWPS/5ki9pdjQ2mFPlzlD2ER/xERY+n/EZn/HYxz5WOVq8QcqL0vzBH/xBdDVQ3pqINmFNTFhT44YZ9Dd+4zdUsxnIcvHpT3+67BjQDKccoPYXfuEXvvd7v7cbkog3LyfjSSO0SNaIzp8/LxtoNiIQQsgnXcSTwkpoIcTAQdNz7B7e+I3fOFLVkKQyQ0S2NOV5cLSvj/qojzLfEFR8cWZa5EAFQwj3VmoC3i7HpKLclQ6wUwmdRslqByZaA/KXf/mXR4NmXp0lar1WyqmVgkMwypYhVHf0aDRkEyKUmNOQ13md1zkqvlmpsUV1okclF6Tp9HGPexw91nuceMhDHpKJLiWoxCmJFUSiwqmkzBCSDDv6Lm0Ev/3bv93Cvf5Ml6IHR8OqlakZHbz+Z0KlzSYD0/D7aZ/2aaJMWIaY59RcW5ghfVEyo8nL5FyOfoM3eAN7KfozYLZiaGBC332FV3iFtrMqSaZiZRvckpYclkYACQm0/yoIZoKRKp8PeMADtD4cTlhz2wNnCR9ZgikgAESjdMYDzPCnIVY6kzlDeqEXeqGIQ0atqAeNg+jIjD6KAbBUE5r8ASyvDcUBKyC50F/xuccx0XPnzllrAfQBa8bDXAwVieCwYvmgD/ogY5qWkYox5ttr66wGGBOSjUqPcyLLNJ3TypnqAapJOU8JYqZEktSCEJV0/vz593qv9+KBjH3P93zPZ3/2Z1uH4Ed2yVL1VGcIspNAW09nQENn84ugM65PJ2MCxwPjSZ01DMeazT6BIQsfFWMbkLkt4zkMPbKWqIJzxrM4P0vtxMylYgk9w3vc4x4f+7Efa/A0iysZSVTR+Umf9ElqxYitP1VbywTBveYiCs9UD0R9UhZFp1yiKMYkib7qq76q7YJNkoxZRHg9w3z0AiDiBPEwoyFJOHFlptaEJAUKKVWr7yZO8WF6iIdmAtgTJn6yQLAcTB6aaX74h39Yn3DSYaFh42xkswHQUXQpS83s1VTMF33RF6k/Uh6ahXXswFsDmALo+YRP+A
TrNB2IbzbyDsOUko2nvQ4fVuFDvOIlNph4WGzIWIVTanFu0gBhmh5kw3aBuDGty4GAwaa2g5FtnIw4QJEh4Jtka2I
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=138x138>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-124847.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"generate_video = True\n",
"\n",
"ext = \".mp4\" if generate_video else \".png\"\n",
"filename = f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}{ext}\"\n",
"\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"reverse_diffusion(\n",
" model,\n",
" sd,\n",
" num_images=16,\n",
" generate_video=generate_video,\n",
" save_path=save_path,\n",
" timesteps=1000,\n",
" img_shape=TrainingConfig.IMG_SHAPE,\n",
" device=BaseConfig.DEVICE,\n",
" nrow=4,\n",
")\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 29,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:06:41.407250Z",
"start_time": "2023-02-13T14:03:27.834241Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:21:11.522425Z",
"iopub.status.busy": "2023-02-22T16:21:11.519897Z",
"iopub.status.idle": "2023-02-22T16:23:35.317409Z",
"shell.execute_reply": "2023-02-22T16:23:35.316345Z",
"shell.execute_reply.started": "2023-02-22T16:21:11.522385Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-124907.mp4\n"
2024-04-08 11:37:01 +02:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [01:32<00:00, 10.82it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIiARIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPquk6TrP2eyg8ny4vssL7d0KMeWQk8sTya4//AIXb8Q/+hh/8krf/AON0fG3/AJK9rv8A27/+iI68/oA9Rs/if8WNQ0y71K01KeaytP8Aj4nSwtysf1Oys7/hdvxD/wChh/8AJK3/APjdb8l5H4P+AUFpC6G+8SSu8gPOIwdpx+Cr+teQ9qAPb/EPxR8Z2Pw58F6rbaxsvtS+3fapPssJ8zy5gqcFMDAOOAM96xLD4j/F/VHVbG7vJyxwuzTYeT/3796zvE8rQfCj4bTRkb0bUWXIyMi4U9Krt8VfiBqsiWkGsXGXwiQ20KL6cDauewoA2dX+I/xe0HH9q3l1ZhjgGXToAP8A0XWV/wALt+If/Qw/+SVv/wDG67jUtX1Dwv8ADG+g8c3r6jq+sAfZNNuMFrYcguT27fkPevB6APcPDvxR8ZX3w48aatc6xvvtN+w/ZJfssI8vzJir8BMHIGOQcdq4/wD4Xb8Q/wDoYf8AySt//jdHhL/kkPxF/wC4b/6PasTwF4cTxZ4203R5WAimkLSjdtJRQWYD3wDQB0KfGL4mSLuTWZWXrldPgI/9F1E3xr+IqMVbXyrDqDZW4P8A6Lr0Lx78WLrwRrR8L+GNIs7eGzREMkkRz2OFAIGO2ec5qHSPhrq/jfxPa+JfGFjaW+nXVqzyQ2uYiMDC7h2Jzn8KAKXwt+KXjLxH8RtK0nVtZ+0WM/neZF9mhTdthdhyqAjkA9a4/wD4XZ8Q/wDoYf8AySt//jdanw2srLTf2grOy065FzZwT3SQzD+NRBJzXmVraz3tzHbW0bSzSHaiL1JoA7ofGz4iEgDxBknsLK3/APjdK/xq+IqMVfXyrDgg2NuCP/IddtoWleH/AISeGrPVvF2mC81u/JMdq6K5gAGRweh9T61LceLPh78TFfTLnS/7I1fUCAL0xoWWQH5cvgZB70AcF/wu34h/9DD/AOSVv/8AG67D4pfFHxl4c+I2q6VpOsfZ7GDyfLi+ywvt3Qox5ZCTyxPWvMfGHg3VPBWsf2fqaKQ674ZkOUlX1Fb3xt/5K9rn/bv/AOiI6AJIfjN8SLmZYoNceSRjhUSxgJP/AJDrtYr/AOOz6Nc6nLe/ZobdDIyT2tursAM/Kvl81x/w98eaL4J0q
4li0Jr7xFMxWKZiMIuOMcE9fTrXa6Fq3jZ7i88aeM57m20qztWMNmzGGOd2GFATofxzQBwP/C7fiH/0MP8A5JW//wAbrsPEXxS8ZWPw48F6tbaxsvtS+3fa5fssJ8zy5gqcFMDAOOAM968TmkEs8kgQIHYttXoMnoK7zxb/AMkh+HX/AHEv/R60AH/C7fiH/wBDD/5JW/8A8bpV+NXxFdgqa+WY9ALK3JP/AJDrgY9pkUOSqEjcQMkCvoyw0/TfB0tjp3hPwbPqusXNmtyNRu8bApXIfceBgnoMdKAPOpfi18UoFLTapcRgDJLadCOPX/V103h34peM774ceNNWudZ332m/Yfskv2WEeX5kxV+AmDkDHIOO1dRrnxGHhHw7NaeItXg1zX54yRZ28KCGHcBgMcdueuc15V4XkM3wn+JMpCgu2nMQowBm4boKAGf8Lt+If/Qw/wDklb//ABuj/hdvxD/6GH/ySt//AI3XP+GPBmteL5bpNIt1k+yxmSVnbaoHpn19qj8N+FtQ8TeJoNDtE/fPJtkfqI1B+Zj7CgDs0+KfxWk0aTV11OY6dHII3ufsFvsDHtny66H4W/FHxl4j+I2laVq2s/aLGfzvMi+ywpu2wuw5VARyoPBrV+LeqaJ4L+H8HgLSUQzyhC4AB2KCGLt/tMR+vtXnXwS/5K9oX/bx/wCk8lAB/wALt+If/Qw/+SVv/wDG6P8AhdvxD/6GH/ySt/8A43Xn9ei/BrwfB4s8Zo15k2dgBcSJjIkIPCn2z19qAEf4y/EmJFeTW5ERvus1hAAf/IdNf40/EeIgSa8yEjIDWMAyP+/de0eItOvfGnjXTNKXRYB4X0S5ElxNLlVlcDBRQOMDpj+leT/HHxFo2u+KLSHRmR0sITbyvGgClg3AUjqAOKAPq2iiigD5A+Nv/JXtd/7d/wD0RHXAV3/xt/5K9rv/AG7/APpPHXn9AHrXxI8O3Vl8NfBN3FGTZx2QEhUcI8n7zJPXnd+leS16Z45+I1l4j8BeHNAsY5o5LOFFu94wCyIFGPXpn8a8zoA9A8W/8kh+HX/cS/8AR61zPhfxPqHhHWBqmmeULkRtGDIgYAHHb14rpvFv/JIfh1/3Ev8A0etcn4e1Cw0vW7e81PTV1K0jJL2zOVD/AI0Ad0nxRtfEqva+PtJi1KIpthurZBFPCeehHUc9Ki174ZWn/CJN4s8Kau2paQhxLHMmyWHpnPY4yOK1rb4l/DyA4Hw3hAL7+Zlfn8V6e3Sszxp8XJdf0V9B0TSLbRtHkIMkUSjc59OAABwOgzx1oAp+Ev8AkkPxF/7hn/o9q43R9Wu9C1e11Sxk8u5tn3o2M8//AKq7Lwl/ySH4i/8AcM/9KGrgKAPs0R6PPoVh4n8WWOlx3ccCNJcSRhhHkjHJ964n47eLb/TPDWn2+j3apa6oGWWSMZLpjoD2Bpg8beBfiH4JstP8R6v/AGWYnjae237N5XtnB+U+2DXn/wAYvG2ja8dL0Pw8Q+m6ap/egHazYwApPOAM8980AZXwS/5K9oX/AG8f+iJK5DQri3tdf0+4u1ZraO4RpVU4JUEZwa6/4J/8le0L/t4/9J5K4AHFAHuX7SCIdV0OXc5c27DGflxu/nXhyI0jqiKWZjgADJJr18eMPCXjzw7pdj40vb6yu9NQjzLaNT53YckHHHar9n4v+FXgmM6h4c0m41HVePKN0D+7PY5bIUj2GfegDN+MbvaeG/BukXpMmpwWCvO5OT0AwfxrB+Nv/JXtc/7d/wD0RHXKeIvEOo+KNZn1TVJzLcStnH8KDsqjsBXV/G3/AJK9rn/bv/6IjoA9V+GPw+tPDeh2Wrz2aajrepRLPau0e6O1UqCMk9DznP4VnftB2uvXFtpUUEN1cadCheeWNfkMpOPmA74H61w+j/HDxRovhiHRbdLR/ITy4rqRC0iKOg64OOg4pvhL41eI/Dkl0LthqsN1KZXW6ZiVY9dpzwPbpQByvh/wXrniVrpdO
s2ZraEzOJMrlR2GeproPFv/ACSH4df9xL/0etbnir48atrmjS6bpdhFpaTZWWWNiXK+intx1NYfi3/kkPw6/wC4l/
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARIAAAIiCAIAAACLx9BEAAEAAElEQVR4Aez9Z6Cu2ZbWdQsoigLmhAJmzBlFPcDe54AZQYKiiIggjQgK3dAi3U3XqaNgI2CgW1sbA6ASDNioJLFOV5UBUFAxY074mrP4KsL7W+v/7GvPup9n71Mtfnvr/jDXmGOOcY0xc3jmfa9v8lt+y2/5bT55PimBT0rgG1MC3/QbI/yJ7Ccl8EkJPJeA2cbzTb7JN6k8RnzTb/pNR48gM5pA0W/2zb4ZZtHoC5RoVuKnRcUTPQ7it/1tf1thzynwive2v6eVh7o5GcTp9slPkQwB4aJxRE8rKf52v91vl6To5BEXhJLOkNb5SPKEmZXTsejTxKkbPQTEJRWnLOdVqcvLDKV48TzhTGfiAv726KxMLCuFmBdiqfNqig8dK/XeyrSG/zE5Wcn66QOcm5W3GMtGasvJDCOGLvVeYEySD61cHEr+NEcR8xTL+mkLnRunFSqTiUjxhDr5pQrHHHGfdMnLJEdQYYhX40QsHD/wet3GC6nUs3I6PE9iniArgZP5JhpOSbSWlxmKEHoGm+kTMJCTEy2kePJJzko4D8OpUH8ocDInfDIfWlkWUrlXHOeeOMFH36z0h869uwHN8DRJVscJVOtS7xGmkpW3CEzynmCFuWydVqANUKrnvuAmcA+LkxZCHtH3MpghnKkXKwkII06QONONYCv+KS/Jg1Mov1k5M57DhcQWPQn025+LS5e8BDuESxT/wrlkYYonH/NiBUe+TpmTXtZOWyd9b6UWmJWHklP5gsTpCfrS+EVvVvyBNV9Fo+NfzAD6zb/5Nyd/ClC5j7Lxf//f/3cIpeZECPEvihdzp6F7Sc4EO9MR+X8P9XE4s4LwnK6mPqaKlzrTJzjHiMk7md/0m36TJFHh3ENUOPhDuMhkGpRnxXhaAY4/dUkkf7ff7Xf70//0P/07fsfv+Hv+nr/nv/wv/8u/8Bf+wv/uv/vvTq3o2Z1L41wEYPLkknqRGX7CopOPkxVZvpTnvTzdMQc7tHEi8EsK/5KXC84leoFaNLFgxzyjWXm9kSB0mp8Ogho4lVS2Q/nj/rg/7k/4E/6Eb/EtvoXq+Zf+pX/pAi26il9hwc/EwEW/3bf7dv/n//l//h6/x+/xu//uv/sf/Af/wUaO//K//C+p/0f/0X+E+G/+m/8m5BxIsbyBXdJMDBmR2IiET4F7OvcoSgqTVghx5j9i+JKKElZKC+szpQp7aK0khybpIY0Zf85HCIGDKvX3+/1+P71FSf6Bf+Af+Kf8KX+KbgPwT/6T/+R/89/8N//Ff/FfXO+9efDoD5yyM0NFK4TThwnEBP47/o6/oxz9b//b/0ZY9f1f/9f/VdIlU2cdBTKxeYR/isUn9jv/zr+z7PyRf+Qf+at+1a/S3oixiO+Z7oiYcM4cic7zJC/RmPdi+PdWPtJtSGRpY2E61MqMVLQH//t//+//Y3/sj1Vqf9ff9XfpNjie1BHJIPL+Ke14ftff9Xf99t/+2xsRdRX1/T//z//zH/QH/UH6zJ/0J/1JTOstiv5X/+pf/c//8//8L/tlv+zf+/f+PU1EMR0AT+RGr4uJnCSQz/c18Xv/3r+3dtbMzpP/76vnP/vP/jNjM0NTnO69dRnM7oRxyLOe8EqAod/n9/l95O4/+A/+AyYuqau8lW2FX2bRwQrjB1vItAI0fn3mM59RHcowrVIV73f+zt9Zt1nvna0EWEy+sIwMP0LStBCi8X+H3+F3+Nbf+lv/oX/oH/pH/9F/9O/0O/1OoP7tf/vf5qEu9G/9W/+WbP4P/8P/kPDFSmizkumZuBilm1fGgp/6U3/qH/vH/rFf8zVf86//6//6b/yNv3EqiNGhLTxzlMwA803VNFjLiyQV9O/8O/9OPmifyvO//+//+1//6389c8O8EYSSuyY894GTyfAZ/fqv/3pwdH/Gz/gZJ595fWkcUfSzkdfnda
r5l//yX46pDQllT6ig/4v/4r/4/zw///V//V/jiP7cn/tzv9N3+k7f/Jt/8wG+iWCIiudNAvEtY/6Gv+FvULX//r//7ysm8nxg+tf+2l/7t/wtf4t++1D9zPuzkdd5IV8eNZp0ETgelaH0/7w/78/7aT/tp9H6O/6Ov+MP+8P+MJ3WtKAj3RtKER+RlWHeC+PoGD/hJ/yE/+Q/+U8SFhoB9Pxf9+t+nbalyt977z3jQrr8qRjlBV2OUjzB4xc+5eHV+j4Oyd/+t//tDTrf83t+z7/5b/6bK8BZ/9//9//dkPfhhx9+8Rd/sYwTfmjlNDHYex+mbiViaGZF/zfV/AF/wB9QsVx07/MCgf8n8knLyKc//emf+BN/4n/8H//H6f5j/9g/llEmNAbmzAfaalofyctpbH5ELJracxk+BaLWZr/iV/yKGv1XfdVXJSCclxfdrGCWYcWqfH/Db/gNphS1rlB+6S/9pT/qR/2oL/qiL/qSL/mSH/ADfsBf9Bf9RdblGgHFn/Nzfo7ZeT4g4HhmNJrpMy9Pjr4qsgkjDFr/y//yvyQp1GP/1//1f21doZj0qGC1MA1ONmdlRLqgPAaICKkRwiS1m7/2r/1rZTD58vKv/Cv/yr/wL/wL2rRRE/ipUsmsQaQV1CUjosS03X/in/gn+EyS//BN2v/UP/VP/fl//p9Py/r585//vBKu1udVgIPNCrRTAP6i+UNeTpUJo3/un/vn/qJf9It0TkWnAbCrtxhAeaIk/4//4/+Aye4P+SE/hFaeZyXMIZ+eoMdHLL/JqJT/8X/8H4EwYTz91Kc+tXH5xMzKBVZ0aDMRx6DzwQcfpFX4d/6dfyd5fMUod7UKo0MLE0kQknw9LeBiCS9PopiASuL0n/qn/qmmMOa1AANA/CRzLqj8m2ImCesJ/9w/989J/Tf+jX8j52jVAhoR7Xa+4Ru+wfZG9A//w/9wtWV+UIXqiXrgc+xNbid8+vaTftJP+kE/6Ad9y2/5LankJ0PGGwvFb/WtvhVbpjrykn6v3+v30pn/8//8P/95P+/nWU9rDbMeYAh8ji8vZTNY0T/zz/wz/+q/+q/mPH7ji9AygyIVrc3yRi+CQBgzfiDorAjVGZlFI/6QP+QPMVp91+/6XbP+n/6n/6lR7J/9Z/9Z00uroz/qj/qjfv/f//dXtoZqKgBzLOS8nZUKigAOwDMjJUFQON/n+3wfBWgZZlUmSv7X/Jpf84//4/+4dSDapKpJfPfv/t2pm3CsdmgFlc+FGSXvOVPnTForz7/sL/vLrD/tbXiCacgzEimQC4LoaeWkszLA7Ipa3Kr3+UMFrD7z5V/+5X/On/PniLKoqeuoilGU4msnUR6FdVqKfpMrZrd/4B/4Bwwt2px5QwlOlwqHRIVhIrLHCv4wObROPPV0RdWN4couk5Zp54//4//4yVyIVLL7lJPf8luCzdAAaenq/+q/+q/WSy3//ra/7W9TaqZjtn7f3/f3tWpSZNPV6GUQ2t/39/19sx7ms5FbXnBizquinP/X/rV/Lck3hb/yV/7K7/bdvtsUEcsLkLQu4AkrcLOxET0Ze78f/aN/NM/tN4b2U37KT5HKxB/xR/wRmHAGjlA1OJ4QEhBFPGwJ3+bbfBuAphRNMBW50zesC7JoMLJ8+G//2/9WKsJ4UUmGnMpZF2k9u/BkNNOFS0IY4H7xL/7F1LVgoSnOWHbBGUhWUj9BTk40BHOmgVhj0Ib1lv/pf/qfjJ6GA8+P+TE/BpQkfMSP+3E/7kS4WbkYuy81NvaUMQOYqqLooMboMtBlHkGlKIJAVi7MFIklKUz4d/ldfpe/5q/5a0w1tExl2lYCp6EkhT2S9MOsTEzSaLX4w3/4D1fuydhmKKCljpA1ncSe5Ku/+qurKgX6g3/wDyYwH0LI24qL6VPA8uyf/qf/aWKakZnKYPzOO+/8JX/JX6IC3n//fRwTqX6rCdpQaWHMzQFEmJe8TIBdj4UKD82EP/SH/tA/8U
/8E+1rJ4BQYiYfMzlwUfKK4lnvKcDJBPpipdQTKvrP+rP+LLkgbN9vffEVX/EVHPgO3+E76MDmT2sbtoKy5P6r/qq
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=274x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-124907.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"# Run reverse diffusion for 128 images; generate_video selects the .mp4 or .png file name.\n",
"generate_video = True\n",
"\n",
"# Build a timestamped file name whose extension follows the chosen output type.\n",
"if generate_video:\n",
"    ext = \".mp4\"\n",
"else:\n",
"    ext = \".png\"\n",
"filename = datetime.now().strftime('%Y%m%d-%H%M%S') + ext\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"# Report the target path before the sampling call starts.\n",
"print(save_path)\n",
"\n",
"# NOTE(review): timesteps is hard-coded here; confirm it matches the schedule sd was built with.\n",
"reverse_diffusion(\n",
"    model,\n",
"    sd,\n",
"    timesteps=1000,\n",
"    img_shape=TrainingConfig.IMG_SHAPE,\n",
"    num_images=128,\n",
"    nrow=8,\n",
"    generate_video=generate_video,\n",
"    save_path=save_path,\n",
"    device=BaseConfig.DEVICE,\n",
")\n",
"\n",
"# Report the path once more after the call returns.\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 30,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"end_time": "2023-02-13T14:03:27.628269Z",
"start_time": "2023-02-13T14:00:18.376509Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:23:35.323346Z",
"iopub.status.busy": "2023-02-22T16:23:35.322602Z",
"iopub.status.idle": "2023-02-22T16:27:47.271778Z",
"shell.execute_reply": "2023-02-22T16:27:47.270801Z",
"shell.execute_reply.started": "2023-02-22T16:23:35.323305Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-125042.mp4\n"
2024-04-08 11:37:01 +02:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [02:56<00:00, 5.66it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAIiAiIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiNquk6TrH2exg8ny4vssL7d0KMeWQk8knr3rj/APhdvxD/AOhh/wDJK3/+N10HjrSdH1r9oHV7PXtU/s3TykLSTggHi3jwMkEDP0qW4+Hnwtid1/4WAy5wUBZGwMDrhaAOa/4XZ8Q/+hh/8krf/wCN0f8AC7fiH/0MP/klb/8Axuuo0v4F2GuvJJpHjWyvLdeR5UO5gD0yc/0ryvxBod34a1670e9KG4tX2sUOQeMgj8CKAPXPEXxS8ZWPw48F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM964/wD4Xb8Q/wDoYf8AySt//jdHi7/kkPw6/wC4n/6PWuAALEADJPAAoA7/AP4Xb8Q/+hh/8krf/wCN0f8AC7fiH/0MP/klb/8Axuux8I+EdE8H+ApfEvjPQVnvpJR9ht5XJaUEDauzpknJ5zxXMfGyy0+z8X2JsNNi077RpsM8tvEgQK7FuNoGARjmgDpvDvxS8ZX3w48aatc6zvvtN+w/ZJfssI8vzJir8BMHIGOQcdq4/wD4Xb8Q/wDoYf8AySt//jdHhL/kkPxF/wC4b/6PauAAJIAGSemKAPZPBPjr4o+N9abTrLxIkOyJ5XmksYCqgDgH5O5wPxrnpvjP8RYJ5IW8QgtGxUkWdvjIOP8AnnXTa9ep8KPh9b+HNNlX/hINXj86+uEwdkZGNvI9DgfjXixyeTQB7f8AC74o+MvEfxG0rStW1j7RYz+d5kX2aFN22F2HKoCOVB4Ncf8A8Lt+If8A0MP/AJJW/wD8bo+CX/JXtC/7eP8A0nkrz+gD0D/hdvxD/wChh/8AJK3/APjdTSfGL4lRW0Nw+ulYps+Wxsrf5sHB/wCWfrWF4A07w7qniu3t/E96bTT8Fi27aHI6KT2B6cV7Jqvjr4WvPaeGovD8eoWgb7OJYLcDy+cDafvNn1B5oA8y/wCF2/EP/oYf/JK3/wDjddh8Uvil4y8OfEfVdJ0rWfs9jB5PlxfZYX27oUY8shJ5JPWvOviL4Yg8IeNr3SLV3e2QLJEXOWCsAQD9M4rV+Nv/ACV7Xf8At3/9J46AFX41/EV22rr5YnsLK3P/ALTp8nxm+JMShpNcdAehawgGf/Idaml+KvC/hL4f2cvhy2t7nxXO3+kSXlr5r
R8nO3JwvbGOver97eeJNW+Fms6p41Ma2spT+yw9oiSeb0O0KBtUgAfhQBzH/C7fiH/0MP8A5JW//wAbrsPEXxS8ZWPw48F6rbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM968Pr0Dxb/ySH4df9xP/ANKFoAP+F2/EP/oYf/JK3/8AjdPh+M/xHuJ44YdeZ5ZGCoi2NuSxPAA/d1w+o6bd6Vd/Zb2IxTbFfaf7rAMD+RFdx8FtBn1r4jWMyJm3sM3E7HsMEAfUkj8jQBY1P4s/FHRr+Sx1HV5LW6jxvilsbcMuRkf8s/Sul8O/FLxnf/Djxpq1zrG++037D9kl+ywjy/MmKvwEwcgY5Bx2rzX4hau2uePdYvW6G4aNfovyj+VbXhL/AJJD8Rf+4Z/6UNQAf8Lt+If/AEMP/klb/wDxuj/hdvxD/wChh/8AJK3/APjdcj4f0hte16y0tJ4oDcyBPNlbCr6kmvVrzwZ8JPDszrqPii+vpYvvwQMregx8q9efWgDl/wDhdvxD/wChh/8AJK3/APjddh8Lfil4y8R/EfStJ1bWftFlP53mRfZYU3bYXYcqgI5APBq/4d0f4R+Lk1CDTfD+oQ/Zbd5Xu5ZZQkagfeJ3kA98Edq4P4MBF+MujLGcoGuQp9R5EmKAGf8AC7fiH/0MP/klb/8Axuj/AIXZ8RDgDxBz/wBeVv8A/G64Dvivo3wD8JNC0PSLLxX4jkcSxwC4eC6wIoeM5Yd8UAedf8Ld+KB6arcf+C6H/wCN1W/4Xb8Q/wDoYf8AySt//jdfTA8b6VJrWi6dDJ5iaxbvPa3Kt8hCjOPrXyP428NT+EvFt9o87hzE+6NwfvI3Kn2JBHFAHqHxS+KXjLw58R9V0nSdY+z2MHk+XF9lhfbuhRjyyEnkk8muP/4Xb8Q/+hh/8krf/wCN0fG3/kr2u/8Abv8A+k8dcv4W0C58TeJLLSbYDfPIAzHoqjkk+2KAOvn+MHxNtoIJp9akjjuFLws1jAA6gkZH7v1BqD/hdvxD/wChh/8AJK3/APjdbvx4so4dS8PzaewfR108W1qysCoMbMCB+G2vIqAPcPEXxS8Z2Pw58F6tbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM964//AIXb8Q/+hh/8krf/AON0eLf+SQ/Dr/uJ/wDo9aqeCJ/AVpaXk3i+1vb25yBb28JZUxjkllIOf8KALf8Awu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N1snV/gzNEynw5qluzcbkuZGK+4y2Kt+IvDfgX/AIVfdeIdI0i+tGMqRWVzdzsGnYk5wucEcHmgC54d+KXjK++HPjTVrnWd99pv2H7JL9lhHl+ZMVfgJg5AxyDjtXH/APC7fiH/ANDD/wCSVv8A/G6PCX/JIfiL/wBw3/0oavP6APVNB+JfxY8Takun6Tqz3FwwyQtlb4UepPl8CotZ+K3xM0LV7jTLzxFEbi3ba/lWtuy5+vl1o/D/AFeXwJ8Lde8SMrCfUJls7H5f4wrZbPXA/pXk15eT395Nd3UhknmYvI56sT3oA9p+F3xS8ZeI/iPpWk6trH2ixn87zIvssKbtsLsOVQEcgHrXH/8AC7fiH/0MP/klb/8Axuj4Jf8AJXtC/wC3j/0RJXn9AHoA+NnxEJAHiDJP/Tlb/wDxutu98e/GXT9I/ta8ubuCwyB58mnwBeen/LOvOvDV9cab4jsLq0sYb66SVfJt5kLK754GARnmvavHk3ifTfhfc23iGWe/1fW7lJDDGhZLKNcHbxwOmPrQB5//AMLt+If/AEMP/klb/wDxuuw+KPxS8ZeHPiPqulaVrH2exg8ny4vs0L7d0KMeWQk8knk14h0OCOa9D+MwjPxm1gTEiIvbbyOoHkR5oAj/AOF2/EP/AKGH/wAkrf8A+N0f8Lt+If8A0MP/AJJW/wD8br0Vfgx4E8Q6XHd+HvEU0YYcO8quPxGAar6D8E/Cg8SiwvvE6ajMsYmFl
AQjMvcsQTx9CDQBwX/C7fiH/wBDD/5JW/8A8brsPEXxR8ZWHw58F6rbazsvtS+3fa5fssJ8zy5gqcFMDAOOAM964H
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAIiCAIAAAB61jR9AAEAAElEQVR4Aez9d6C32XrX9fNT7L333hULiiACMk+iooASRaRFRBOBSNEQopgYmTMHooEYUbpiDGJARSMRFRUkmTl2xd6w9957xd/red57Pmed9d37mUHjf7P+WPta1/pcZfVy3/v+/v/+7//7//52n4RPauCTGvikBj6pgU9q4P+bGvj1/r9R+4nWT2rgkxr4pAY+qYFPauBNDTjNCN/+23/7X+/X+/X+f2/CKqYk/q//6//6Y0qOBhh9ErSVBC5kBX3C0KeG6MUXGH9qT6kZwszKqfZUgj4FK1RqowOHmaDkQpqzgnkayo3pLDlA4CVP4qKXXFlwrjBt+Cf9LGyAESfsLVamPMGr9qZtFQV/YZaclUmdPkSfejI9cFmS4zwSRF6y8kbuqaNSdXbmTAM8yy9XfPqWlemcJxf4FFnWRST7qAFsZUkkbTx/1vlLrSTYlCNO+nRsVsLIOnPR8SmMRqRqFqc5zAmYyGosqYmsLKfOkz6dmcURV+7Kkg/0CGEQmPMnwOLcKDekrOVGLz6tvIQZOOLy88x97eLDFAqfFdPdY+4pftInchZP5gmOvsryCEjPtF2AlMsVnjUU88nKjOGWQd2Is79mZkrD1EjDj7h8mpX4L8Hii0dceh6TIeNnZZzXWj7sYdye6ZN/Khwm5pIrMr62v8pyasjcxTmZ6Tw5c+aSetbKJXiKvJ0+BSFLimdlY36FfXT1NDGFrf2S40z/8FlJ4Zgjxj81LDfms1nDIABWlikMMNmVcZwAF/5U+5j19rIkO0OnqtGX9Wf5K4vc8JdOTOFyryT+pKJPWJyMnlbmBuIl/IlBc4m2FBJB5OQb3tO6jn608liWNKcqGmZunApjnkh4zNNKzsRP2+LpPF3FHH/ImRhB5LQyJOJR/My96Ckcn2bM8bOSh8NcxP9Li2xdZXlj/6nVsjV/JE9zo0/A5d6ST1auIiV5xhPAZEAszHDJaqRcWQESjM4KwFl3ZZ3xqTbxaZtOxEkHq+SruAGmJD3xgQuYv81v89v8/J//83/gD/yBv9Vv9VsFhlluzT8lic/KygL/aBHnkTlViMfckzMrIV/r+jCkpNSp8NQp98pa7pk1K2lTkEepOEkFGKwakDXiWfFZmarTh0eRsz5P5ElPFSJ+Vk5ZWbkKgBAQhYl/yHimusJMCQL4KsswI9YrxjkJGs4k+pGDeVnh+aQu+nWpjtzBTgIgKyHnYVaedeAUH/0W5FuyrrJc2iZ4Eq+L9CaczJOmRP6p6lkrK+mJHF0uPdMsK/rkhAf7DX6D3yArMz3wDC1rVh5VLWvEhXm2LAMjHq2cuSd9IZdk8dHK5QZwnEml+UrO3Ck++snKWXHLI6niJNOIWJPImvnoJZ81n5Jni/QsfswExUIlGTG7gSeSFTCcjpyJnIKpSuQ3+81+sx/6Q3/o//F//B//1D/1T33X7/pdUxtAnJ51oDiYK8ul9kQGXjyC3d/it/gt9FechelBCDDCrKAxB0a8QX0Op9zLgYmc4id9nszSeeaedEZTOP6IGRpxZqFXlkuDosU5j0SX7DCQZ9aZjH9ZSfMpHufZOA3TPyLwcml71sqzOj8mM+X5ObuzMs4KcrbyctmKj3MyX6LTNiuSNcFZ3pf8//1//9//+37f7/vee+/92X/2n/07/o6/42DZmsURswKJOb7kCnUpWfIkBp6Gk7isEDxzTz3RcgcYmImFYJKTVcOnlel5iZjgiN/r9/q97Gi/9/f+3uoQMf6IXDqtnA4M9kiA/WF/2B/2M3/mz/xZP+tn/Ql/wp9wAuh8Vslp5cRHn5XzmHty1iGfNZSVp4cov/bX/lqScLjF/9f/9X/NHk7JYAMz8H/+n//nTI4/TjbE46RcsqySaqEkPiUxx0l2ghHZKgsyDUviDMDJ0QHEcX7z3/w3/4Iv+AID7Dt9p+/0e//ev/c//o
//4wFoAyimCpPR7GJOybIiVFEiIc/cOATfeecdNfbP/XP/nLUttWDCLFIycIYenYefDyexJjuZaArFSU0W82y78ZO9fJj4iAB/xB/xR3y37/bd/tv/9r/9e/6ev+d/+B/+B7mrtEcNaZ6Gtcs8mQ8vyQ6QqsHw0dN/5l61t+a7+Gme/hGpunIvQ88m51u53+t7fa9/8V/8F/+z/+w/O9sIRsiT4ssuWYBZR+OUrJ4HGGyAOFe83Kv4knHghWjgBUx2Ndkf88f8MX/BX/AX/FF/1B/1+/6+v+83fuM3/qP/6D8aBiACLPDFn8OXzsFWzDi/8+/8O3+H7/Adfuvf+rf+nX6n3+k3+o1+o3/yn/wn/+F/+B8u6zRBraQwPSOyWG60WG6lwy/JWyHmKkEupvjkpPmKYXACR1yAJWF+2A/7YT/8h//w//Q//U//3X/33/0j/8g/Euebv/mb/5f/5X8JIymkME5uT8NgJyYRO9fv9/2+35d8yZcQ+V1/19/1N/lNfpP/4D/4D2j+9/69f48UzKOScS6js4I4SzfY6RVmAbIwtQhZJZ+WGQnCcDKKcSY2RYPpcAYMQFqSnXn4mkduCoOJJxKn3IHDD3wSExwxMM7JnFT+NHFjlsxVeKeKz/u8z7PAyNIk/8V/8V+8UfPZGZxXU9ucSPacJnJgGMkKgrNSn0UDMCn/h//hf/jP/rP/LFpItsqUpDwn32Q+E03t8n7j3/g3tqP8TX/T31SWez/mlOtf+9f+NV15rubDZCNOt9MWPwfkrqIec0P+lr/lb/mjf/SP/hE/4kcYM//8P//P/yv/yr8CedmanwhSZwVC5oM4hYHPGpAlTEkwcdqytdyPQ1wiv8Pv8Dv8oX/oH2oWo9Oi262IbQcf/v6//+9vK0BtLp1OztazTLlzWwP9sX/sH/s1X/M17733nsV4jZLgYNMzIhMcxhkTvmQFkeTblNRLB15NBiCCIzedk1pZcvsl/m/4G/6GBstf8pf8JX/an/an0aO5bcv+q//qv0o8Kfwl40zbCA68ZCjPKfldfpff5Uf9qB/1+Z//+Vqn22xV1zIDI6z4Uzu7xMutqoFhKnJ8yDBLjqgC05lgcZzLVlnTdiJP2RwTG6TW5sIf/8f/8Tg/7af9NNPOt3zLt+QA/ZeJ+DNEpLkCBx14sfH4P/1P/5Phb0/jtPRrfs2vMTB/wS/4BVpqvY5UZT8NnTRAYXUieQKIsy5gRlC+ZGAYBA2fFUR9NvHtvp2RFih10AomXpBcFiYD4QMEftSQFWBjWJzgdL6FMAXY0Ths2tqoO49STjBD2UrnrJQs67R1OvCn/Cl/yr//7//7RFTHX/6X/+W/7W/726b5Ku8pHp2VwM/Gs1tuzij4n/gn/ol//V//13/1V3+1U609mnJhPqsB8yOt2LD8Hr/H7/F9vs/3cVL+Vb/qV/2D/+A/+B/9R/+RBYzgX/fX/XX2mya4S3meYI64rIwfRnKcEbIq4F/0F/1F/9a/9W/RwChz+C+FywrYG8Wvu8FC+sWrvdPiYM8SiTxrJfxUnfrNZdq9wZDsGWugy1ZKZkUyDtiIS0TyD/6D/2CbGFI/+Af/4LNFHkVOzqzQsE4S4LF/Php9iUN22mCyMrsjZJ10Igry9/19fx8R67Hwmc98xp5phuC1wqRO4rJyikQHXrv/nr/n7/lzf+7PTcp+/H/+n/9n9L/wL/wLlpzLsalChD85L9GKc7oabA6fUnMJMzoriZ+5xM/kqWS0yrdwWi9TUvyLf/EvNq2FmQ9lSZ6NtdwpvDh/3B/3x1lXTuUWABOO2cZUM/eSkgxJ26Vn+t9OXFJLjkj8yUp/zObz4ySSWYyIDrMYEb9Y7SBO5op0lipAIvl00sakjvVn/Bl/xnvvvfcv/8v/Mg1WfvRVeCKTOq2A0S8Mjx7yD/wD/8C/4W/4G+D/t//tf/vX//V//bt8l++iE8iFQQz5WvWHYczTiswMzQoO5JIRNhp//p//5//H//
F/nOz//r//70r0U37KT/ke3+N7WNdTconMyqO2PPpz/9w/14YlmH23LvU//o//o05sscH8pm/6pj/8D//DL51nUjE
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x546>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-125042.mp4\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"# Run reverse diffusion for 256 images; generate_video selects the .mp4 or .png file name.\n",
"generate_video = True\n",
"\n",
"# Build a timestamped file name whose extension follows the chosen output type.\n",
"if generate_video:\n",
"    ext = \".mp4\"\n",
"else:\n",
"    ext = \".png\"\n",
"filename = datetime.now().strftime('%Y%m%d-%H%M%S') + ext\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"# Report the target path before the sampling call starts.\n",
"print(save_path)\n",
"\n",
"# NOTE(review): timesteps is hard-coded here; confirm it matches the schedule sd was built with.\n",
"reverse_diffusion(\n",
"    model,\n",
"    sd,\n",
"    timesteps=1000,\n",
"    img_shape=TrainingConfig.IMG_SHAPE,\n",
"    num_images=256,\n",
"    nrow=16,\n",
"    generate_video=generate_video,\n",
"    save_path=save_path,\n",
"    device=BaseConfig.DEVICE,\n",
")\n",
"\n",
"# Report the path once more after the call returns.\n",
"print(save_path)"
]
},
{
"cell_type": "code",
2024-04-09 09:31:18 +02:00
"execution_count": 31,
2024-04-08 11:37:01 +02:00
"metadata": {
"ExecuteTime": {
"start_time": "2023-02-13T14:02:13.807Z"
},
"execution": {
"iopub.execute_input": "2023-02-22T16:27:47.273986Z",
"iopub.status.busy": "2023-02-22T16:27:47.273382Z",
"iopub.status.idle": "2023-02-22T16:30:05.576098Z",
"shell.execute_reply": "2023-02-22T16:30:05.575242Z",
"shell.execute_reply.started": "2023-02-22T16:27:47.273949Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"Sampling :: 100%|██████████| 999/999 [01:29<00:00, 11.10it/s]\n"
2024-04-08 11:37:01 +02:00
]
},
{
"data": {
2024-04-09 09:31:18 +02:00
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQgJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAESAiIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP+KXxS8ZeHPiPq2k6TrP2exg8ny4vssL7d0KMeWQk8knk1x//AAu34h/9DD/5JW//AMbo+Nv/ACV7Xf8At3/9ER1vfD34bacmgv4z8ZsYdIiXfDbk4MwHc+3oO9AGD/wu34h/9DD/AOSVv/8AG6P+F2/EP/oYf/JK3/8AjddxD8S/h9qWu2ukxeBrFdOuWET3LW8ayJuOMgBc/rXmnxK8OWvhbx7qWlWTE2yMrxqTkqGUNj8M/lQB6H4i+KPjKx+HHgvVrbWdl7qX277XL9lhPmeXMFTgpgYBxwBnvXH/APC7fiH/ANDD/wCSVv8A/G6PFv8AySH4df8AcS/9HrWj8LfhRd+K7221TVbYroAJJO/a05HGBjnGepoAzv8AhdvxD/6GH/ySt/8A43Tx8ZviSyM6645RerCwgwP/ACHXrXiLRE8HTyzaN8NdK1OxTGyZT5siYHVlbJ615TqPxj16exvNNtdM0jTba4yssVvZgdQQQc559+tAHU+Hfil4yvvhx401a51nffab9h+yS/ZYR5fmTFX4CYOQMcg47Vx//C7fiH/0MP8A5JW//wAbo8Jf8kh+Iv8A3Df/AEe1RfDuX4fxG9PjaC4ndiotljLhQOdxJQg56UAS/wDC7fiH/wBDD/5JW/8A8bo/4Xb8Q/8AoYf/ACSt/wD43XdPdfBOyubaTTNEm1G7d1WK3jeZ8semQzEHnAwc9af8edC8P6b4X0i7s9ItNP1Ga4C7beMR/u9hJBCgA87ecUAUvhb8UfGXiP4j6VpOq6z9osZ/O8yL7LCm7bC7DlUBHIB61x//AAu34h/9DD/5JW//AMbo+CX/ACV7Qv8At4/9ESVwFAHf/wDC7fiH/wBDD/5JW/8A8bo/4Xb8Q/8AoYf/ACSt/wD43Vr4a+ArHV7S78TeJvMi8PWCl2Ibb57L1X1x9MGvRfCviPwbrMGqPb+BdPh8P6TA7yX00CO7KoyByuSx+uaAPMP+F2/EP/oYf/JK3/8Ajddh8Uvil4y8OfEfVtJ0nWfs9jB5PlxfZYX27oUY8shJ5JPWvHdYubW81q9ubK3FtaSzO8MI6RoTwPwFdl8bf+Sva7/27/8ApPHQAf8AC7fiH/0MP/klb/8Axuj/AIXb8Q/+hh/8krf/AON1PHpmleGvh
Cms3NvHd6t4gaSC1Z1DC2jRsMRno2R1681y/gzw/J4n8W6dpaq/lzTKJXVchV6kmgDof+F2/EP/AKGH/wAkrf8A+N12HiL4o+MrH4ceC9WttZ2X2pfbvtcv2WE+Z5cwVOCmBgHHAGe9ecfES40+bxzqUelWMdlY20n2eKKNNv3PlJP1IJzWv4t/5JD8Ov8AuJf+j1oAP+F2/EP/AKGH/wAkrf8A+N05fjV8RWzt18nHXFjBx/5DrlvC/hnUPFuuwaTpyAyyn5nb7sa9ya9zmPg/4J+HHspooNZ1+5AZ0dAckdM5zsUZ+poA82k+NHxHiYCTXWQkZAaxgH/tOut8O/FLxlffDnxpqtzrO++037D9kl+ywjy/MmKvwEwcgY5Bx2rX1/8As/4tfCN9eit7WLXtLBaRIRjYB1TrnaVwRnuDivOPCX/JIfiL/wBw3/0oagA/4Xb8Q/8AoYf/ACSt/wD43R/wu34h/wDQw/8Aklb/APxuvP66/wADaz4V0J7y88Q6LJqs6hfssO/amc/MW/D1zQBo/wDC7fiH/wBDD/5JW/8A8brsPhb8UvGXiP4j6TpOq6x9osZ/O8yL7LCm7bC7DlUBHKg9avaxo3gbx/8ADjVtb8M6aun6hp6CaVFXaVKgkqR0IIDcj0rgfgl/yV7Qv+3j/wBESUAH/C7fiH/0MP8A5JW//wAbo/4Xb8Q/+hh/8krf/wCN1wABJAAyT2FeqaP8MtI0nwmniXx5f3Vhbzlfs1pbAea4PqCD27DFAGV/wu34h/8AQw/+SVv/APG6P+F2/EP/AKGH/wAkrf8A+N113hDTvhd4w1xdCsPDWorLNFI32qW7f91tBOcA45968d1W1isdXvLSCQyRQzPGjn+IAkA0AeyfFL4o+MvDnxH1XStK1n7PYweT5cX2WF9u6FGPLISeST1rkU+NXxFkYKmvlmPQCxtyT/5Dpvxt/wCSva7/ANu//pPHXS+BYbDwL8NJfHstpBe6pcT+RZJKOIiCVyPfIJoA5+T4x/EuFd0usyovq2nwAf8Aouov+F2/EP8A6GH/AMkrf/43Xtb675Hw4tNZ+IvkXIuyuLDyFCcn5TyCc45yDxXh3xU8I2/hbxTu05lbStQjF1ZlTkBG6j6Z6e2KAO28RfFLxlY/DjwXqttrGy+1L7d9rl+ywnzPLmCpwUwMA44Az3rj/wDhdvxD/wChh/8AJK3/APjdHi3/AJJD8Ov+4n/6ULXHaNJYw6tbyalZy3lorZe3ifY0ntnBxQB20Xxk+JU5xDrUkh6fJYQH+UdLdfGH4mWUvlXWtSQyYztksIFOPxjrtvDmteMteAtvDOhWHhTRYwWkvhbfMiY5JZvvHjriua+MPjXRvEMWm6Xpsg1G4shi41R4wrTHGMAgDjqfSgDZ8O/FLxlffDjxpq1zrO++037D9kl+ywjy/MmKvwEwcgY5Bx2rj/8AhdvxD/6GH/ySt/8A43R4S/5JD8Rf+4Z/6PasLwr4I1nxjJMulJBsgx5sk8oRVzkjk/Q0Abv/AAu34h/9DD/5JW//AMbo/wCF2/EP/oYf/JK3/wDjdaQ8JeA/B8f2jxLrv9uXi9NM01tqk4PDuDkDPcEVJ4x8CaPZfD6PxdHaSaLcXcyLa6YZjIDGe5LZbdjJ6446UAbXwt+KXjLxH8R9J0nVtZ+0WM/neZF9lhTdthdhyqAjkA8GuP8A+F2/EP8A6GH/AMkrf/43R8Ev+SvaF/28f+iJK4CgDv8A/hdvxD/6GH/ySt//AI3R/wALt+If/Qw/+SVv/wDG65Tw74fv/E+uW2k6dEZLidsZ7Ivdj7Ac16Xra/C7wS500aRc6/rFq3l3DvO8cQcdehweeMYoA5//AIXb8Q/+hh/8krf/AON12HxS+KPjLw58RtW0rSdZ+z2MHk+XF9lhfbuhRjyyEnkk8muI8TeP7HW9NfTrHwhounwbVCTRwnz0I9HBGfxFT/G3/kr2u/8Abv8A+iI6AD/hdvxD/wChh
/8AJK3/APjdH/C7fiH/ANDD/wCSVv8A/G6x/A3hCXxhrn2ZpPs+n26+de3RIAhiHU89/Su6ttL+GniHXJ/C2h6Xfe
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAESCAIAAABPRClNAAEAAElEQVR4Aez9d8Bv23bX9QfE3nsvYFdEFDtRz84VUdRYULBFwW40CZAIEnLl3n2AJMQgGDQEkQhSbD8japSiufeca1fsFbH3Dvbu77X3e5/PnWd+n2ffG4G/POuP+Yw15hifMWYfc671rO93+X/+n//nCz66PqqBj2rgoxr4qAY+qoFfMzXwXX/NwH6E+lENfFQDH9XARzXwUQ28rgG7mW1ovst3+S7f9bv+Gll4TiufT8Xz5BKbYyMuASqzMvURCbul/mv9Wr/WpesWv9zoJ62ENiskQSV5GRr+c/wJZLfb07GsfE71UyD6TFeKi7nb00rMqeSSW16VFWc04lGYDHlXwqVnjeFcihfIqVjWLJ66Y07mshLOxAab/FuMTvJCCOq08ghOZf4MJ7EnhU/mWWlZ+W7f7bsNJOJye+qI0U9KnjhJSs+yELgQTpUn6Uv+up1KVuR2XUVITNbkE3uSPxnEqeL2ubJMjN3Rqe92RPjnbTTd1Gelxhp/Kjhl4az5TmEmyORA5s6UJMXV2CWWYkZldXuqR+dMKZmIR7GVRRaZTD+K4Txp6LSSFk7MVLp9Y+U0dklP59FSFZe89JQcM6KsRyvTOnVXnpP5JCCByURITyuvsl/33dLLXLV/ClTLc+BJozGzMthL8rqtoi7Y1R6Qsk60mGdZwkxyupcht0O7MAcesVtQj1aCJTOx0/n4OMuNc3r46Nhq7BQmNuSTSMYQnYlHwJk7ZWblOfn4p8qjZLmP6STPGpvYiMQqTswpPtlwZ00mmdisrGZOtJk7mTN0EgSSmfVuk5mVk3mqv53+PLWyUkmnMoKJlfE0dwrgn7fRJ4dAVi7mwJlYDcwK4eRLxz+JfJvYWWOn2yed0WEiRsu6bhPOoqxwrrIM/MSJeXICKZWrvMsdsly0C3FZSbGs0edtzGXNq/FxxkzR7ZsaW8VN+kkitTUV/XkQsdtLPcOXlak8ao0z4gJ87pb8aWXqEUtPz9ErEYHRk0G4ZjGQy8oEZiIiLfRun6RPdQK7Pa0MChFaYkOehwhZT/IvGbfKm5Xk36L1mIVz+jCB+X+aW1mIrZIJBHLpJvAkzjCfzD3LkuQlNkPDOYmEL5nH25Ul3QRmCDEVtOs0gV7uxb9uZ+WUP2nygWMiypJGXIYwJ3wampWT+Z2lBQSabHZP9ZiP7ZJMuVNEjCbQbekwz86TTBxiZ1nCmW7EdMudiauN5F6crMdcWSaDcKWFmKsjMresV9Lf9bteG53TBHpW4s9hioP9VSECvKwEOBNkXG6lshCuX/vX/rXdrs7xyz2JcEqpZOXNrpy0+/JOOv1E0f/X//V/sSH9v//v/3vCEVMfP87Jz6dxRqRy4vCPiWyduZ8nXRECyejAT6MrhVx85UKcRgkACSeZcJaeaOiz6qIxqU8sQpYr6/OBWA4MfARhiq4IfLR0t5mICVARpG5P/tAQcqc7/pOccpcFMPVsra7mT8Uh74oePl2cmHLjd3vSKmG6iJAJjI7Y7fAjyMtCI4BPDBFdemqdjqWe0Uvylf4HbqfulmSVjFPlpLsmOA2h5a7Sriy3AT7y49A9BZQuz7NIZsRJp1LWEE6c58w9x790/8//8/8k+ZZChXNpjTkiD4m51nBniUjWNxArziphOIjnqiU+/KlHX40FM7RBIU4mdYrJjI9p8eiWIbcuMiRd0eEwV6XhN+Qvh0OeOuJEKHfpb/fb/Xbf/bt/d+m//+//+//D//A//Pq//q//y3/5L/9f/pf/JUxpRiePmDMx85YJzmcLLSv3yg2NtwTy+QRZ659+BpKJN+U/WWU8pieE3Aw/iuFckjj5pBkQa5hLl1aKZwEumSdv80Ra+YEkFtqTKpjlSl1ckq755Z6lO+mEUx8I4rmLvKxKFA7OeSv35Jc7Kw2AySBwqHcRo1upkxkfgUOgdHUep1xZEbMiF6
dUrmv0iPhPppf66u2yki7hiHJ3SwtnDpOZ6QFGjP8kzqyEf+JcdJJLL9gcG0hETl6SEPClp/ODfY4I5ILKSqYpIlyY8QcVc7cn8Vv/1r/1r/Pr/Dr/2X/2n3Hmqq5pheb2VHyOnlZoz4k98p+zkl25JzF6fJwV4UnTMbOCTrHb05lT96Qnw9AJgg6NQDQ3EiB5IUyX8EV3W5p6FnFcgZ+wTeXhJPlcyofv/b2/95/2p/1p3/f7ft/f4rf4Lf7T//Q//S//y//yN/6Nf+P/3+vrf//f//cp5i1zMySrXioLLavbqeDI2twy/h/5R/6Rf+Af+Af+T//T//RP/BP/xL/z7/w74yOyEmd0Fu9njCu5qYeN0jTPOsK53DrRJxlaWamsnJNZbgLn7WgDxn7tf/1f/9ff4Df4DX7H3/F3/F1+l9/lv//v//v/4r/4L9Ts//a//W+PnrALPxP5IE1sJZpMYtJCjIwSniQ6YVm1SsQrAx/M5sNP3W1EAtE5ME63ssa/2mZZEfmwhr9u5y3heRVx3l5+Ek4+hx/lp5u55H/L3/K31K1/xa/4FTz/nX6n38mM9l/9V/+VPvc//o//49AoroBp7bZihlxWWtLcezJL7hAiLrFBhVPuKRPNRFaIrcJP3akQwz9vT5ru2iL1E42umvk9fo/fQ2ilX5UlxvyP/qP/6D/5T/6TmSMG81Rc1ohyM31Jzp9f99f9dX/X3/V3ZUsw+zv8Dr/Db/ab/WZuf9Pf9Dd97733PvOZz/zKX/kroc0QIrszEfEkczKztVZY1pOKb2cu9+y3AIEvK4tuXRl9NE1ljp00lW4RCYQcE2f8ScpyrYbxM4dD/jI9sXIHuOIspJbVFf7SmFnHjLisYBIrN+Kkv8/3+T5f9VVf9af/6X+6uRFfZzNPun733/1359XP/bk/F7MLbA4HCGp2h78SUUnscgbfGvbX/rV/7Rd/8Rebin/0j/7Rf/Pf/Defi9mgSNKdCbdvlpnZWGEiphnxm/wmv8lv/pv/5qry1/v1fj1jzAXOKDLp/8//8//MZOgXyOzhV4AJ8MB1MlWTocIEKIuKXKlFG/O//W//29/2t/1tf/AP/sF//B//x/93/91/Zwj9pJ/0k4yi1BXhNdhnk7NQs8jnJOBHQDZvsqhnuFgxacpKMvCBDmfqsjDHd3uq8KEsyKbm3+g3+o0qmlsLpNZyCQ1qLYrzeRYRMeVmqyyw6sqKa/Iyc61c5bKCo//9hr/hb6ixmHBRCSSZpXkoyzV6uYiYcXS1v+wv+8t+t9/td7NJV3Vf9EVfJJj61//1f/1v+pv+pp/9s3/2VpocCPOEQld1s3XmzhDPtbvpUkXBVEbFoYiY/JN1Ve5wJjxzCDjdno04yRGBzMqJeRaBh4ZAsOly/g/4A/6Av/Kv/Cv/4r/4L8YfoIXZyPzWb/3W//q//q9repjDTyzHpoKYk6eJiWmC3+q3+q2Y+3P/3D/3t/ltfpvf+/f+vXWJqf85f86f82N+zI/5ju/4DgFZFVjWWZY4A5zu508MbSAjBhJnkuOvdDhyiSFKE54App6sbo0jl/6sM5h2VObQZiUizNXwmOMkkHrm8kHqGn+KcaQLSbnXiCYPdsNwAoR1Y8Pw//g//g/zpFE/5HrOrOQVW3FmFH8qoxn9S/6Sv+QH/aAfxKLY5V/71/61f+6f++dMj9//+39/s+UP/+E/XP38fX/f3xeIlJ8DRLhmBb1KnjMIRhObdZP/H/vH/rGyNMQf+of+ob/n7/l7sksmAWnunbcB3lPkG+7DH6Uys1g84f57/96/Z+QbJ+Bc/+K/+C9a4v6IP+KPeFB6w8gPkhPI+91GYBonL168UEc/8Sf+xL/xb/wbf8kv+SViZIPTLExdhepYCLdaEWG+G8hpBc1hWRGmY/VijTRFmrzM9Vod4bLAvPPOO5/61KeguXSCn/WzftY5UIEENUNJdntaydzERrD++/1+v9+f8Cf8CX//3//3r9KAKMW//C
//y1/zNV9jfzbhEacVjVdBpOjJCGT+7X/73/47/o6/Ax+z9BRwaPvjftyP+2f+mX/mL//L/3KFnWLE2SdwlOXEQXd
2024-04-08 11:37:01 +02:00
"text/plain": [
"<PIL.Image.Image image mode=RGB size=546x274>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-04-09 09:31:18 +02:00
"inference_results/20240408-125343.png\n"
2024-04-08 11:37:01 +02:00
]
}
],
"source": [
"# Run reverse diffusion for 128 images; with generate_video off the file name ends in .png.\n",
"generate_video = False\n",
"\n",
"# Build a timestamped file name whose extension follows the chosen output type.\n",
"if generate_video:\n",
"    ext = \".mp4\"\n",
"else:\n",
"    ext = \".png\"\n",
"filename = datetime.now().strftime('%Y%m%d-%H%M%S') + ext\n",
"save_path = os.path.join(log_dir, filename)\n",
"\n",
"# NOTE(review): timesteps is hard-coded here; confirm it matches the schedule sd was built with.\n",
"reverse_diffusion(\n",
"    model,\n",
"    sd,\n",
"    timesteps=1000,\n",
"    img_shape=TrainingConfig.IMG_SHAPE,\n",
"    num_images=128,\n",
"    nrow=16,\n",
"    generate_video=generate_video,\n",
"    save_path=save_path,\n",
"    device=BaseConfig.DEVICE,\n",
")\n",
"\n",
"# Report the path only after the call returns.\n",
"print(save_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-04-09 09:31:18 +02:00
"version": "3.11.8"
2024-04-08 11:37:01 +02:00
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}