From 16861f0f3dee1b7fc9bc00bd8b286dd6f9ea4af3 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 27 May 2021 17:46:00 +0800 Subject: [PATCH] Update GeMOSA v4 --- xautodl/datasets/synthetic_env.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index aaa1b98..077e826 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -127,6 +127,10 @@ class SyntheticDEnv(data.Dataset): targets = torch.from_numpy(targets) else: targets = torch.Tensor(targets) + if dataset.dtype == torch.float64: + dataset = dataset.float() + if targets.dtype == torch.float64: + targets = targets.float() return torch.Tensor([timestamp]), (dataset, targets) def __len__(self):