Common API examples
1 Converting a tensor to NumPy
PS: A tensor on the GPU cannot be converted to NumPy directly. Move it to the CPU first and then convert it: .cpu().numpy()
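A minimal sketch of the GPU-to-NumPy path described in the note. The torch.cuda.is_available() check is an addition so that the snippet also runs on a machine without a GPU. On a CPU tensor, .cpu() returns the tensor itself without copying, so the same line works in both cases.

```python
import torch

x = torch.rand(2, 3)
if torch.cuda.is_available():
    x = x.to("cuda")  # put the tensor on the GPU when one exists

# Calling x.numpy() directly on a CUDA tensor raises an error,
# so move the data to host memory first with .cpu().
x_array = x.cpu().numpy()
print(x_array.shape, x_array.dtype)
```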
1.1 tensor.numpy()
import torch

# torch.rand already returns float32; .type() just makes the dtype explicit
x = torch.rand(6).view(2,3).type(torch.float32)
print(type(x))
x_array = x.numpy()
print(x_array,type(x_array))
output:
<class 'torch.Tensor'>
[[0.9542696 0.8235684 0.6300868 ]
[0.16127479 0.40761203 0.22885096]] <class 'numpy.ndarray'>
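For a CPU tensor, the array returned by tensor.numpy() shares its memory with the tensor, so an in-place change to one is visible through the other. The sketch below demonstrates this on a small zero tensor. The values are chosen only for illustration.

```python
import torch

x = torch.zeros(2, 3)
x_array = x.numpy()    # no copy: x_array views the same memory as x
x_array[0, 0] = 5.0    # modify the NumPy array in place
print(x[0, 0].item())  # the tensor sees the change
```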
2 Converting a NumPy array to a tensor
2.1 torch.tensor(x)
import numpy as np
import torch

x = np.array(3)
print(type(x))
x = torch.tensor(x)
print(type(x))
output:
<class 'numpy.ndarray'>
<class 'torch.Tensor'>
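Unlike tensor.numpy(), torch.tensor() always copies the data, so later changes to the source array do not affect the new tensor. The tensor also inherits the array's dtype, which for a NumPy float array is float64. The array values below are illustrative.

```python
import numpy as np
import torch

a = np.array([1.0, 2.0, 3.0])  # NumPy floats default to float64
t = torch.tensor(a)            # copies the data
a[0] = 100.0                   # change the original array
print(t[0].item())             # the tensor still holds the old value
print(t.dtype)                 # dtype carried over from the array
```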
2.2 torch.as_tensor()
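torch.as_tensor() avoids a copy when it can: for a CPU NumPy array with a matching dtype, the resulting tensor shares the array's memory. When a different dtype is requested, as with dtype=torch.float32 on a float64 array in the example below, a converted copy is made instead. A small sketch of both cases:

```python
import numpy as np
import torch

a = np.ones(3)                                       # float64 array
shared = torch.as_tensor(a)                          # same dtype, CPU: memory is shared
converted = torch.as_tensor(a, dtype=torch.float32)  # dtype change forces a copy
a[0] = 7.0
print(shared[0].item())     # reflects the change to a
print(converted[0].item())  # keeps the original value
```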
x = np.ones(5)
print(type(x))
x = torch.as_tensor(x,dtype=torch.float32)
print(x,type(x))
output:
&l