-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtorch_datasets.py
67 lines (50 loc) · 2.32 KB
/
torch_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations
from typing import Dict, Sequence, Tuple
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
class MyTensorDataset(Dataset[Dict[str, Tensor]]):
    r"""Dataset wrapping a dictionary of named tensors (or equally long sequences).

    Each sample is a dictionary with the same keys as ``tensors``, obtained by
    indexing every value along its first dimension.

    Args:
        tensors: mapping from name to a tensor, a NumPy array or a sequence.
            NumPy arrays are copied and converted to float tensors. All values
            must have the same size along the first dimension.
        dataset_id: identifier stored on the instance as ``dataset_id``.

    Raises:
        ValueError: if ``tensors`` is empty.
        TypeError: if a value is neither a tensor, a NumPy array nor a sequence.
        AssertionError: if the values do not share the same first-dimension size.
    """

    tensors: Dict[str, Tensor]

    def __init__(self, tensors: Dict[str, Tensor] | Dict[str, np.ndarray], dataset_id: str = ""):
        if not tensors:
            # Without at least one entry there is no way to define the dataset size.
            raise ValueError("MyTensorDataset requires at least one entry in ``tensors``")

        # Copy NumPy inputs so the dataset does not share memory with the caller's arrays.
        tensors = {
            key: torch.from_numpy(tensor.copy()).float() if isinstance(tensor, np.ndarray) else tensor
            for key, tensor in tensors.items()
        }

        # The first entry defines the reference size; it may be a tensor or a sequence.
        first_key, first_value = next(iter(tensors.items()))
        self.dataset_size = self._first_dim_size(first_key, first_value)

        for k, value in tensors.items():
            size = self._first_dim_size(k, value)
            if torch.is_tensor(value):
                assert size == self.dataset_size, "Size mismatch between tensors"
            else:
                assert (
                    size == self.dataset_size
                ), f"Size mismatch between list ``{k}`` of length {size} and tensors {self.dataset_size}"

        self.tensors = tensors
        self.dataset_id = dataset_id

    @staticmethod
    def _first_dim_size(name, value):
        """Return the first-dimension size of ``value`` or raise ``TypeError`` for unsupported types."""
        if torch.is_tensor(value):
            return value.size(0)
        if isinstance(value, Sequence):
            return len(value)
        raise TypeError(f"Invalid type for tensor {name}: {type(value)}")

    def __getitem__(self, index):
        """Return a dict mapping every key to its value indexed at ``index``."""
        return {key: tensor[index] for key, tensor in self.tensors.items()}

    def __len__(self):
        """Return the shared first-dimension size of all stored values."""
        return self.dataset_size
def get_tensor_dataset_from_numpy(*ndarrays, dataset_id="", dataset_class=MyTensorDataset, **kwargs):
    """Build a dataset from NumPy arrays converted to float tensors.

    Every array in ``ndarrays`` is copied, turned into a tensor and cast with
    ``.float()``. The resulting tensors are passed positionally to
    ``dataset_class`` together with ``dataset_id`` and any extra keyword
    arguments.

    NOTE(review): the default ``dataset_class`` (``MyTensorDataset``) takes a
    single dict of tensors rather than positional tensors -- confirm whether
    the default is still usable or callers always pass their own class.
    """
    float_tensors = []
    for array in ndarrays:
        # Copy first so the tensor does not alias the caller's array.
        float_tensors.append(torch.from_numpy(array.copy()).float())
    return dataset_class(*float_tensors, dataset_id=dataset_id, **kwargs)
class AutoregressiveDynamicsTensorDataset(Dataset[Tuple[Tensor, ...]]):
    """Dataset of (input, target) pairs drawn from a single sequence of time steps.

    Sample ``i`` is the pair ``(data[i], data[i + horizon])``, i.e. the target
    lies ``horizon`` time steps ahead of the input.

    Args:
        data: indexable container of time steps that supports ``len``.
        horizon: number of time steps between input and target; must be > 0.
        **kwargs: accepted and ignored.

    Raises:
        AssertionError: if ``horizon`` is not positive.
    """

    data: Tensor

    def __init__(self, data, horizon: int = 1, **kwargs):
        assert horizon > 0, f"horizon must be > 0, but is {horizon}"
        self.data = data
        self.horizon = horizon

    def __getitem__(self, index):
        """Return ``(data[index], data[index + horizon])``.

        Integer indices follow the usual sequence convention: negative values
        count from the end of the dataset, and out-of-range values raise
        ``IndexError``. Other index types are forwarded to ``data`` unchanged.
        """
        if isinstance(index, int):
            size = len(self)
            if index < 0:
                # Resolve relative to the dataset length, not len(data); otherwise
                # ``index + horizon`` would wrap around to the start of ``data``.
                index += size
            if not 0 <= index < size:
                raise IndexError(f"index out of range for dataset of size {size}")
        # input: index time step
        # output: index + horizon time-steps ahead
        return self.data[index], self.data[index + self.horizon]

    def __len__(self):
        """Return the number of valid pairs (0 when ``data`` is shorter than ``horizon``)."""
        # Clamp at zero: a negative value would make ``len()`` raise ValueError.
        return max(len(self.data) - self.horizon, 0)