fixed bug with alternate_corr flag
This commit is contained in:
		
							
								
								
									
										30
									
								
								core/corr.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								core/corr.py
									
									
									
									
									
								
							| @@ -60,26 +60,6 @@ class CorrBlock: | |||||||
|         return corr  / torch.sqrt(torch.tensor(dim).float()) |         return corr  / torch.sqrt(torch.tensor(dim).float()) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CorrLayer(torch.autograd.Function): |  | ||||||
|     @staticmethod |  | ||||||
|     def forward(ctx, fmap1, fmap2, coords, r): |  | ||||||
|         fmap1 = fmap1.contiguous() |  | ||||||
|         fmap2 = fmap2.contiguous() |  | ||||||
|         coords = coords.contiguous() |  | ||||||
|         ctx.save_for_backward(fmap1, fmap2, coords) |  | ||||||
|         ctx.r = r |  | ||||||
|         corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) |  | ||||||
|         return corr |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def backward(ctx, grad_corr): |  | ||||||
|         fmap1, fmap2, coords = ctx.saved_tensors |  | ||||||
|         grad_corr = grad_corr.contiguous() |  | ||||||
|         fmap1_grad, fmap2_grad, coords_grad = \ |  | ||||||
|             correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) |  | ||||||
|         return fmap1_grad, fmap2_grad, coords_grad, None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AlternateCorrBlock: | class AlternateCorrBlock: | ||||||
|     def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |     def __init__(self, fmap1, fmap2, num_levels=4, radius=4): | ||||||
|         self.num_levels = num_levels |         self.num_levels = num_levels | ||||||
| @@ -92,20 +72,20 @@ class AlternateCorrBlock: | |||||||
|             self.pyramid.append((fmap1, fmap2)) |             self.pyramid.append((fmap1, fmap2)) | ||||||
|  |  | ||||||
|     def __call__(self, coords): |     def __call__(self, coords): | ||||||
|  |  | ||||||
|         coords = coords.permute(0, 2, 3, 1) |         coords = coords.permute(0, 2, 3, 1) | ||||||
|         B, H, W, _ = coords.shape |         B, H, W, _ = coords.shape | ||||||
|  |         dim = self.pyramid[0][0].shape[1] | ||||||
|  |  | ||||||
|         corr_list = [] |         corr_list = [] | ||||||
|         for i in range(self.num_levels): |         for i in range(self.num_levels): | ||||||
|             r = self.radius |             r = self.radius | ||||||
|             fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) |             fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() | ||||||
|             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) |             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() | ||||||
|  |  | ||||||
|             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() |             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() | ||||||
|             corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) |             corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) | ||||||
|             corr_list.append(corr.squeeze(1)) |             corr_list.append(corr.squeeze(1)) | ||||||
|  |  | ||||||
|         corr = torch.stack(corr_list, dim=1) |         corr = torch.stack(corr_list, dim=1) | ||||||
|         corr = corr.reshape(B, -1, H, W) |         corr = corr.reshape(B, -1, H, W) | ||||||
|         return corr / 16.0 |         return corr / torch.sqrt(torch.tensor(dim).float()) | ||||||
|   | |||||||
							
								
								
									
										11
									
								
								core/raft.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								core/raft.py
									
									
									
									
									
								
							| @@ -38,11 +38,11 @@ class RAFT(nn.Module): | |||||||
|             args.corr_levels = 4 |             args.corr_levels = 4 | ||||||
|             args.corr_radius = 4 |             args.corr_radius = 4 | ||||||
|  |  | ||||||
|         if 'dropout' not in args._get_kwargs(): |         if 'dropout' not in self.args: | ||||||
|             args.dropout = 0 |             self.args.dropout = 0 | ||||||
|  |  | ||||||
|         if 'alternate_corr' not in args._get_kwargs(): |         if 'alternate_corr' not in self.args: | ||||||
|             args.alternate_corr = False |             self.args.alternate_corr = False | ||||||
|  |  | ||||||
|         # feature network, context network, and update block |         # feature network, context network, and update block | ||||||
|         if args.small: |         if args.small: | ||||||
| @@ -55,7 +55,6 @@ class RAFT(nn.Module): | |||||||
|             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) |             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) | ||||||
|             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) |             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def freeze_bn(self): |     def freeze_bn(self): | ||||||
|         for m in self.modules(): |         for m in self.modules(): | ||||||
|             if isinstance(m, nn.BatchNorm2d): |             if isinstance(m, nn.BatchNorm2d): | ||||||
| @@ -103,7 +102,7 @@ class RAFT(nn.Module): | |||||||
|         fmap1 = fmap1.float() |         fmap1 = fmap1.float() | ||||||
|         fmap2 = fmap2.float() |         fmap2 = fmap2.float() | ||||||
|         if self.args.alternate_corr: |         if self.args.alternate_corr: | ||||||
|             corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius) |             corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) | ||||||
|         else: |         else: | ||||||
|             corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) |             corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user