熵编码实现

发布时间:2023-12-01 20:25:16 作者:浪矢-CL
 1     def compress(self, x):
 2         y = self.g_a(x)
 3         y_strings = self.entropy_bottleneck.compress(y)
 4         return {"strings": [y_strings], "shape": y.size()[-2:]}
 5 
 6     def decompress(self, strings, shape):
 7         assert isinstance(strings, list) and len(strings) == 1
 8         y_hat = self.entropy_bottleneck.decompress(strings[0], shape)
 9         x_hat = self.g_s(y_hat).clamp_(0, 1)
10         return {"x_hat": x_hat}
class EntropyBottleneck(EntropyModel):
    """Entropy model that derives indexes and medians itself.

    Both ``compress`` and ``decompress`` build the ``indexes`` and ``medians``
    arguments internally and then delegate the actual coding to the parent
    ``EntropyModel`` implementation.
    """

    def compress(self, x):
        """Compress ``x`` by delegating to ``EntropyModel.compress``.

        Args:
            x: input exposing ``size()``; dim 0 is used as the leading
                (batch) size and every dim past the first two is counted
                as spatial.

        Returns:
            Whatever the parent ``compress`` returns for ``x`` together with
            the indexes and medians built here.
        """
        indexes = self._build_indexes(x.size())
        # Detached so the medians used for coding carry no autograd history.
        medians = self._get_medians().detach()
        # Everything after the first two dims is treated as spatial.
        spatial_dims = len(x.size()) - 2
        medians = self._extend_ndims(medians, spatial_dims)
        # Expand only the leading dim to x.size(0); the -1 entries leave the
        # remaining (spatial_dims + 1) dims as they are.
        # NOTE(review): assumes medians is a torch tensor -- confirm.
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decompress ``strings`` by delegating to ``EntropyModel.decompress``.

        Args:
            strings: sized collection of encoded items; its length becomes
                the leading dim of the rebuilt output size.
            size: iterable of trailing dims appended to the output size; its
                length is used as the number of extra dims for the medians.

        Returns:
            Whatever the parent ``decompress`` returns.
        """
        # Leading dim comes from the number of strings, second dim from the
        # first dim of the quantized CDF table (presumably the channel
        # count -- confirm), followed by the caller-supplied dims.
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        # Indexes are moved to the device that holds the CDF table.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians)

 

class _EntropyCoder:

底层代码中常用的熵编码器是非对称数系编码(ANS)和区间编码(range coding),
然后使用索引(indexes)进行编码/解码。
下面的类负责将概率质量函数(PMF)转换为量化的累积分布函数(CDF),并定义了一个占位符方法,鼓励在子类中提供具体实现。

class EntropyModel(nn.Module):