CRNN推理部分解析

发布时间 2023-08-17 13:49:27作者: 周而輹始


class resizeNormalize(object):
    def __init__(self, size, interpolation=Image.LANCZOS, is_test=True):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()
        self.is_test = is_test

    def __call__(self, img):
        w, h = self.size
        w0 = img.size[0]
        h0 = img.size[1]
        if w <= (w0 / h0 * h):
            img = img.resize(self.size, self.interpolation)
            img = self.toTensor(img)
            img.sub_(0.5).div_(0.5)
        else:
            w_real = int(w0 / h0 * h)
            img = img.resize((w_real, h), self.interpolation)
            img = self.toTensor(img)
            img.sub_(0.5).div_(0.5)
            tmp = torch.zeros([img.shape[0], h, w])
            start = random.randint(0, w - w_real - 1)
            if self.is_test:
                start = 0
            tmp[:, :, start:start + w_real] = img
            img = tmp
        return img

实现了一个图片预处理函数`resizeNormalize`,用于将输入的图片调整大小并进行归一化处理。

函数参数说明:
- `size`:调整后的图片大小,类型为元组`(width, height)`
- `interpolation`:调整图片大小时使用的插值方法,默认为`Image.LANCZOS`
- `is_test`:是否为测试模式,默认为`True`

函数主要步骤:
1. 获取输入图像的原始宽度`w0`和高度`h0`
2. 检查调整后的宽度是否小于等于`(w0 / h0) * h`,如果满足条件,则直接将图像调整为目标大小,并进行归一化处理。
3. 如果上述条件不满足,则计算调整后的实际宽度`w_real`,然后将图像宽度调整为`w_real`,高度调整为目标高度`h`,并进行归一化处理。
4. 创建一个大小为`[img.shape[0], h, w]`的零张量`tmp`,其中`img.shape[0]`表示图像通道数。
5. 在`tmp`张量中随机选择一个起始位置`start`,将调整后的图像`img`插入到`tmp`中,使其在水平方向上居中。
6. 最后返回处理后的图像。

注意:
- `transforms`是PyTorch中的一个模块,用于进行图像的数据变换操作。在该代码中使用了`ToTensor`变换,将图像转换为Tensor类型,并进行归一化处理。
- 归一化处理是的操作是将每个像素值从[0, 255]范围映射到[-1, 1]范围。具体地,对于每个像素值,都执行`(value - 0.5) / 0.5`的计算。

w0 = img.size[0]h0 = img.size[1]

获取输入图像的宽度和高度。

- `img.size`返回一个包含图像宽度和高度的元组。
- `img.size[0]`表示图像的宽度,用变量`w0`保存。
- `img.size[1]`表示图像的高度,用变量`h0`保存。

img.sub_(0.5).div_(0.5)

`img.sub_(0.5).div_(0.5)`是对图像`img`进行归一化处理的操作。

- `img.sub_(0.5)`用于将图像的每个像素值减去0.5,即对每个像素值进行减去0.5的操作。
- `img.div_(0.5)`用于将图像的每个像素值除以0.5,即对每个像素值进行除以0.5的操作。

这两个操作将图像的像素值从原来的范围[0, 1]映射到了[-1, 1],实现了图像的归一化处理。这种归一化操作可以使得图像在模型训练或测试中具有更好的数值稳定性和收敛性。

img.shape[0]

`img.shape[0]`是获取张量`img`的第一个维度的大小,即图像的通道数。

在此代码段中,`img`表示调整后的图像,它是一个三维张量,具有形状`(C, H, W)`,其中`C`是通道数,`H`是图像的高度,`W`是图像的宽度。

通过`img.shape[0]`可以获取图像的通道数,用于创建全零张量`tmp`的大小。

 


 

创建全零张量的意义在于初始化一个与调整后的图像具有相同形状的张量,并将其所有元素的值都设置为零。

在这段代码中,我们创建了一个大小为`[img.shape[0], h, w]`的全零张量`tmp`。这个张量将被用作目标图像的容器,用于插入调整后的图像。

将张量元素初始化为零是为了确保在后续的操作中,没有调整大小的区域保持为零。这有助于避免对原始图像的影响,并确保调整后的图像与所需大小的一致。


 

class strLabelConverter(object):
    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '_'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

 

上述代码是一个名为`strLabelConverter`的类,用于将字符串标签转换成对应的索引标签。以下是该类的主要功能和作用:

- `alphabet`参数表示要转换的字符集,即所有可能的字符。例如,可以是'a-z'的小写字母组成的字符串。

- `ignore_case`参数用于指定是否忽略字符的大小写,默认为`False`。

- `self.dict`是一个字典,用于将每个字符映射到对应的索引。索引从1开始,0被保留用于表示特殊的'blank'字符。

该类的主要用途是提供一个方便的方法来将字符串标签转换成索引标签,以便在一些自动语音识别(ASR)或光学字符识别(OCR)等任务中使用。通过使用该类,可以将字符标签转换为模型所需的数字索引标签,以便进行后续的计算和训练。

class strLabelConverter(object):
    '''
    将字符串标签转换成对应的索引标签
    '''
    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '_'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    # print(self.dict)
    def encode(self, text):
        length = []
        result = []
        for item in text:
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                if char not in self.dict.keys():
                    index = 0
                else:
                    index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

定义了一个名为`strLabelConverter`的类,用于将字符串标签转换成对应的索引标签。该类包含以下方法:

- `__init__(self, alphabet, ignore_case=False)`:类的初始化方法,接受两个参数,`alphabet`表示要转换的字符集,`ignore_case`表示是否忽略大小写。在初始化方法中,根据`ignore_case`的值来设置`alphabet`的大小写,并在`self.dict`字典中建立字符到索引的映射关系。

- `encode(self, text)`:将给定的字符串标签转换为索引标签的方法。它接受一个字符串列表或字符串数组作为输入,并返回一个包含索引标签的Tensor和每个标签长度的Tensor。

- `decode(self, t, length, raw=False)`:将给定的索引标签转换为字符串标签的方法。它接受索引标签Tensor `t`和标签长度Tensor `length`作为输入,并返回相应的字符串标签。如果输入是批处理数据,则将返回一个字符串列表。

该类的作用是在自动语音识别(ASR)或光学字符识别(OCR)等任务中,将字符标签转换为模型所需的索引标签,或将索引标签转换为字符标签,便于使用和处理这些标签数据。


将给定的字符串标签转换为索引标签的意义在于:

1. 模型训练和推断:在使用深度学习模型进行任务训练和推断时,通常需要将输入数据转换为数字形式(例如整数索引),以便模型能够理解和处理。将字符串标签转换为索引标签可以方便地将字符标签表示为模型所需的输入格式。

2. 序列标注:在一些自然语言处理任务中,如命名实体识别、词性标注等,需要对输入文本的每个字符进行标注,即将每个字符标记为特定的类别。将字符串标签转换为索引标签可以将字符标签映射到对应的索引,便于进行序列标注。

3. 词汇表构建:将字符串标签转换为索引标签可以方便地构建词汇表,其中每个字符对应一个唯一的索引。词汇表可以用于模型的输入表示,例如使用独热编码将索引标签表示为向量形式。


 

 

class PytorchOcr():
    def __init__(self, model_path):
        alphabet_unicode = config.alphabet_v2
        self.alphabet = ''.join([chr(uni) for uni in alphabet_unicode])
        # print(len(self.alphabet))
        self.nclass = len(self.alphabet) + 1
        self.model = CRNN(config.imgH, 1, self.nclass, 256)
        self.cuda = False
        if torch.cuda.is_available():
            self.cuda = True
            self.model.cuda()
            self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
        else:
            # self.model = nn.DataParallel(self.model)
            self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()
        self.converter = strLabelConverter(self.alphabet)

    def recognize(self, img):
        h,w = img.shape[:2]
        if len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        image = Image.fromarray(img)
        transformer = resizeNormalize((int(w/h*32), 32))
        image = transformer(image)
        image = image.view(1, *image.size())
        image = Variable(image)

        if self.cuda:
            image = image.cuda()

        preds = self.model(image)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)

        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        txt = self.converter.decode(preds.data, preds_size.data, raw=False)

        return txt

一个用于OCR(Optical Character Recognition,光学字符识别)任务的PyTorch模型类。下面是代码的主要功能:

- `__init__(self, model_path)`:初始化方法。接收一个`model_path`参数,用于指定OCR模型的路径。在方法内部,首先根据提供的配置文件`config`中的`alphabet_v2`创建字符集`self.alphabet`。然后初始化OCR模型`self.model`,并根据是否有可用的CUDA设备进行模型加载。最后,创建字符标签转换器`self.converter`,将字符集传递给转换器。

- `recognize(self, img)`:识别方法。接收一个图像`img`作为输入,执行OCR任务。在方法内部,首先获取输入图像的宽度和高度,并转换为灰度图像(如果原始图像是彩色图像)。然后,根据转换后的图像尺寸创建一个可以输入到OCR模型的图像张量。将图像张量转为PyTorch的`Variable`类型,并将其传递给OCR模型进行推断。通过模型的输出得到预测结果`preds`,并进行解码,将索引标签转换为字符串标签。最后,返回识别到的文本结果`txt`。

该代码使用了CRNN(Convolutional Recurrent Neural Network)模型,通过图像的卷积和循环神经网络结构实现文本识别。同时,为了处理不同尺寸的输入图像,使用了图片归一化和尺寸调整的预处理步骤。通过这个类,可以方便地加载已训练好的OCR模型并进行文本识别。