Numpy中数据类型转换的tips

发布时间 2023-12-25 11:02:04作者: 絵守辛玥

在逛Stack Overflow时看见一个关于numpy的浮点数据转换的问题比较有趣,现当作tips记录下来。[问题原地址](python - Convert numpy array type and values from Float64 to Float32 - Stack Overflow)

我们知道,在numpy中,浮点数据同python本身一样,是用双精度(float64)来存储数据的,而Pytorch或者其他的一些框架中,为了节省运算量,其浮点是用单精度(float32)来存储数据的,因此需要用到数据转换。下例是一个数据转换的操作。

l = np.array([0., 0., 0.])

for i in range(len(l)):
    l[i] = l[i].astype(np.float32)

print(l[0].dtype)

# >>> float64

在该部分代码中,ndarray数组l的数据默认为float64,于是遍历其中每一个元素将其更改为float32。但是在更改完后,取出数组中的第0个值发现其数据类型仍为float64。为什么会造成这种原因,我们用下列例子演示一下:

l = np.array([0., 0., 0.])
a = l[0]
b = l[1]
c = l[2]

a = a.astype(np.float32)
b = b.astype(np.float32)
c = c.astype(np.float32)
print("a的数据类型:", a.dtype)


l[0] = a
l[1] = b
l[2] = c

print("数组中a的数据类型:", l[0].dtype)

# >>> a的数据类型: float32
# 	  数组中a的数据类型: float64

这里从列表l中取出了第一个值赋值给a,改变a的数据类型为float32,打印后a的数据类型的确变为了float32。但是再将其放入数组中后,它的数据类型又变回了float64。因此我们可以发现,数组本身的数据类型决定了数组中每一个元素的数据类型。数组l的数据类型为float64,其中每个元素的数据类型都为float64,即使更改了所有元素的数据类型。可以用下列例子验证。

l = np.array([0., 0., 0.])
l = np.float32(l)  # 将数组更改为float32
a = l[0]
b = l[1]
c = l[2]

a = a.astype(np.float64)
b = b.astype(np.float64)
c = c.astype(np.float64)
print("a的数据类型:", a.dtype)


l[0] = a
l[1] = b
l[2] = c

print("数组中a的数据类型:", l[0].dtype)

# >>> a的数据类型: float64
#     数组中a的数据类型: float32