【科研03】【代码复现】TransUnet道路提取

发布时间 2023-10-15 08:09:14作者: 安知燕雀?

1. 数据准备 data process

  经过科研02部分的数据预处理,我们已经得到了以下内容:

  • train-image & train-label --> train.npz, train.txt
  • test-image & test-label --> test.npz, text.txt
  • val-image & val-label --> val.npz, val.txt

2. 文件更名 files rename

  已经确认了数据内容的正确。

  在复现代码时,尽量不更改代码中的文件名字,因此接下来还需要将这些数据更改为原始代码中使用的文件名。

2.1. 数据更名 npz rename

  05_npz_files文件夹更名为Synapse,该文件夹放在data文件夹中。

  其下有两个文件夹:

  • train_npz文件夹名不做变更。

  • test_npzval_npz两个文件夹中的内容合并,并更名为test_vol_h5

  不用担心test_npz和val_npz混淆在一起,代码会通过txt文档来筛选。

2.2. 文档更名 txt rename

  06_npzFiles_txt文件夹更名为lists_Synapse

  其下有三个文件:

  • train.txt文件名不做更改。

  • test.txtval.txt中的内容合并在一起,更名为test_vol.txt

3. 代码修改 code change

  大多还是依据TransUnet官方代码训练自己数据集中的内容进行的修改。

3.1. 目录调整 contents

  目录安排并未按照上述csdn的链接,data、model、predictions和TransUNet-main文件夹是同级别的。

  TransUNet-main下包含的文件夹包括:networks、datasets、test_log和lists,文件包括trainer.py,test.py,utils.py和train.py。

  可以按照上述内容做一下核查。

3.2. 数据读取 code1

  • TransUNet-main-> datasets -> dataset_synapse.py

  按照TransUnet官方代码训练自己数据集中的内容进行修改。

  dataset_synapse.py文件中的Synapse_dataset类中,修改__getitem__函数如下:

 def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            slice_name = self.sample_list[idx].strip('\n')
            data_path = self.data_dir+"/"+slice_name+'.npz'
            data = np.load(data_path)
            image, label = data['image'], data['label']
            image = torch.from_numpy(image.astype(np.float32))
            image = image.permute(2,0,1)
            label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample

  dataset_synapse.py文件中的RandomGenerator类,修改__call__函数如下:

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y,_ = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        image = torch.from_numpy(image.astype(np.float32))
        image = image.permute(2,0,1)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

3.2. 训练参数 parameter set

  主要是train.py文件中的参数。

  而为尽可能的保证不修改源代码中的内容,博主尽可能的保证文件夹的位置与github中的设置一致。

  TransUnet官方代码训练自己数据集中提到的../内容修改为./的部分都未进行修改。

3.2.1. 目标类别 num classes

  train.py文件。

  下面代码中的default依据label中的物体类别来定义。

  约在第18行。

  • 建筑物识别:类别1是背景,类别2是建筑物,default=2

  • 道路识别:类别1是背景,类别2是建筑物,default=2

  • 土地利用覆盖分类:类别1是大豆,类别2是小麦,类别3是水稻,类别4是其他,default=4

  • ···

parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')

3.2.2. 运行轮次 max epochs

  train.py文件。

  下面代码中的default设定为想要运行多少个epoch来决定。

  约在第22行。

parser.add_argument('--max_epochs', type=int, default=150, help='maximum epoch number to train')

3.2.3. 批次传入 batch size

  train.py文件。

  下面代码中的default设定一个iteration传入多少个image和label。

  • 对于512 X 512大小的image和label,16GB的显存,可以设定batch_size为4,如果显存为12GB的话,应该可以设定为2,如果显存为8GB及以下,建议设定为1吧,不然电脑会卡。

  约在第24行。

parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')

3.2.4. 图片尺寸 image size

  train.py文件。

  因处理的图片是512大小的,故而设定default=512。

  约在第31行。

parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input')

3.2.5. 数据设置 dataset_config

  train.py文件。

  一些不是参数的部分,但需要修改。

  因为数据集的设置,所以只需要更改下面代码中的num_classes为2。

这里其实只是修改了参数部分的默认值?

  约在第60行。

    dataset_config = {
        'Synapse': {
            'root_path': '../data/Synapse/train_npz',
            'list_dir': './lists/lists_Synapse',
            'num_classes': 2,
        },
    }

3.2.6. 保存名称 save name

  train.py文件。

建议每次都对下面这些内容进行修改,确保生成的模型文件是在一个新的文件夹中,而不会覆盖前一次的模型训练结果。

  如果不修改,并且也没修改max_epochs,那么会无情的覆盖上一次的模型结果,导致无法在必要时调用不同阶段训练的模型文件对结果进行测试。

  建议将两行的TU都设定成任务名_日期,如RoadExtract_231013

  • 道路提取任务:RoadExtract

  • 日期:23年10月13日

  约在第67行。

    # 修改前
    args.exp = 'TU_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'TU')
    # 修改后
    args.exp = 'RoadExtract_231013_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'RoadExtract_231013')

3.2.6. 重要修改 important set

  trainer.py文件。

  如果不修改,会出现电脑自动重启or程序莫名中断等问题。

  约在第33行。

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, worker_init_fn=worker_init_fn)

4. 实际实验 real experiment

  train-image & train-label : 8660

  16GB显存,batch-size设置为4,epochs设置为150。

  预计运行时间:1week!!!!!

  这也太慢了!早知道设置epochs设置小一点了。

5. 改进想法 think

5.1. 权重保存 parameter save

  不能保存每一次训练的权重。

  这确实减少了存储空间的损耗,但是导致不训练够100个epoch或训练完所有的epoch,就没有权重文件被保存下来。

  如果在即将训练完时电脑断电,那么可能一周的结果就徒然无功了。

  不说保存每一个epoch,起码应当考虑每5个epochs之类的保存一下权重。

5.2. 权重名称 parameter name

  权重的名字保存下来是无意义的epoch_99.pth或epoch_149.pth,或许可以将权重的名字更改为含有精度评价指标的名字,如:epoch_49_miou-72.33_oa-88.75_···.pth。


  后续有机会将针对这些内容做一些修改。