import dataclasses
import warnings

import numpy as np
import torch


def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of an object recursively."""
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy)
            for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Rebuild the dataclass from its recursively converted field values
        return type(data)(
            *[
                to_device(v, device, dtype, non_blocking, copy)
                for v in dataclasses.astuple(data)
            ]
        )
    # Maybe a namedtuple: a tuple subclass whose constructor takes the fields
    # positionally. There is no reliable way to detect a namedtuple exactly.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        # Convert to a tensor first, then recurse into the Tensor branch
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data


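def _to_device_example():
    # Illustrative usage sketch, not part of the original module: shows how
    # to_device() walks nested containers, converting ndarray/Tensor leaves
    # and leaving everything else untouched. The batch layout below is an
    # assumption invented for this example.
    batch = {
        "speech": np.zeros((2, 16000), dtype=np.float32),  # ndarray -> Tensor
        "lengths": torch.tensor([16000, 12000]),
        "ids": ["utt1", "utt2"],  # non-array leaves pass through unchanged
    }
    moved = to_device(batch, device="cpu", non_blocking=True)
    assert isinstance(moved["speech"], torch.Tensor)
    assert moved["ids"] == ["utt1", "utt2"]
    return moved

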
def force_gatherable(data, device):
    """Change an object so that torch.nn.DataParallel can gather it, recursively.

    The difference from to_device() is that float and int values are also
    converted to torch.Tensor.

    DataParallel restricts the value returned from forward(): each leaf must be
    - a torch.cuda.Tensor
    - with 1 or more dimensions (a 0-dimension tensor triggers a warning),
    or a list, tuple, or dict of such values.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # Promote a 0-dim tensor to a 1-dim tensor
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data


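def _force_gatherable_example():
    # Illustrative usage sketch, not part of the original module: a typical
    # DataParallel-style forward() return of (loss, stats, weight); the
    # variable names are assumptions made up for this example. Every scalar
    # comes back as a 1-dim tensor on the target device, so DataParallel can
    # concatenate the per-GPU copies along dim 0.
    loss, stats, weight = force_gatherable(
        (torch.tensor(1.5), {"acc": 0.9}, 2), device="cpu"
    )
    assert loss.shape == stats["acc"].shape == weight.shape == (1,)
    return loss, stats, weight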