Update GeMOSA v4
This commit is contained in:
parent
b6e11c6360
commit
16861f0f3d
@ -127,6 +127,10 @@ class SyntheticDEnv(data.Dataset):
|
||||
targets = torch.from_numpy(targets)
|
||||
else:
|
||||
targets = torch.Tensor(targets)
|
||||
if dataset.dtype == torch.float64:
|
||||
dataset = dataset.float()
|
||||
if targets.dtype == torch.float64:
|
||||
targets = targets.float()
|
||||
return torch.Tensor([timestamp]), (dataset, targets)
|
||||
|
||||
def __len__(self):
|
||||
|
Loading…
Reference in New Issue
Block a user