Minor changes; add train_timestep_fraction
This commit is contained in:
		| @@ -14,18 +14,15 @@ class MLP(nn.Module): | ||||
|         super().__init__() | ||||
|         self.layers = nn.Sequential( | ||||
|             nn.Linear(768, 1024), | ||||
|             nn.Identity(), | ||||
|             nn.Dropout(0.2), | ||||
|             nn.Linear(1024, 128), | ||||
|             nn.Identity(), | ||||
|             nn.Dropout(0.2), | ||||
|             nn.Linear(128, 64), | ||||
|             nn.Identity(), | ||||
|             nn.Dropout(0.1), | ||||
|             nn.Linear(64, 16), | ||||
|             nn.Linear(16, 1), | ||||
|         ) | ||||
|  | ||||
|         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")) | ||||
|         self.load_state_dict(state_dict) | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def forward(self, embed): | ||||
|         return self.layers(embed) | ||||
| @@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module): | ||||
|         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") | ||||
|         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") | ||||
|         self.mlp = MLP() | ||||
|         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")) | ||||
|         self.mlp.load_state_dict(state_dict) | ||||
|         self.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def __call__(self, images): | ||||
| @@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module): | ||||
|         inputs = {k: v.cuda() for k, v in inputs.items()} | ||||
|         embed = self.clip.get_image_features(**inputs) | ||||
|         # normalize embedding | ||||
|         embed = embed / embed.norm(dim=-1, keepdim=True) | ||||
|         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) | ||||
|         return self.mlp(embed) | ||||
|   | ||||
| @@ -35,8 +35,6 @@ def aesthetic_score(): | ||||
|     scorer = AestheticScorer().cuda() | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         if not isinstance(images, torch.Tensor): | ||||
|             images = torch.as_tensor(images) | ||||
|         scores = scorer(images) | ||||
|         return scores, {} | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user