Update GeMOSA v4
This commit is contained in:
parent
b6e11c6360
commit
16861f0f3d
@ -127,6 +127,10 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
targets = torch.from_numpy(targets)
|
targets = torch.from_numpy(targets)
|
||||||
else:
|
else:
|
||||||
targets = torch.Tensor(targets)
|
targets = torch.Tensor(targets)
|
||||||
|
if dataset.dtype == torch.float64:
|
||||||
|
dataset = dataset.float()
|
||||||
|
if targets.dtype == torch.float64:
|
||||||
|
targets = targets.float()
|
||||||
return torch.Tensor([timestamp]), (dataset, targets)
|
return torch.Tensor([timestamp]), (dataset, targets)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user