Point Cloud Segmentation using DGCNN

Project Introduction

[Figure] 3D point cloud segmentation using the DGCNN model architecture

Summary

Point cloud data is a set of points lying on the surface of some object in 3D space. The data collected by LiDAR sensors mounted on drones or autonomous vehicles comes in this point cloud form. Unlike 2D image data, which sits on a regularly spaced grid, the points of a point cloud are scattered irregularly through space, and their density is not uniform either. Moreover, the shape a point cloud represents stays the same no matter what order its points are stored in, so any model that processes it should be permutation invariant.
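
To make the last point concrete, here is a small sketch (my own illustration, not part of the original pipeline) showing that a symmetric aggregation such as max pooling over the point axis, the kind of operation these models rely on, gives the same result for any ordering of the points:

import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(2048, 3))   # a toy point cloud: 2048 points in 3D
perm = rng.permutation(2048)          # an arbitrary reordering of the points

# Max pooling over the point axis is unaffected by the order of the points.
print(np.allclose(points.max(axis=0), points[perm].max(axis=0)))  # True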

Therefore, the grid convolutions optimized for learning on 2D data cannot be applied as-is. Starting with the PointNet[2] paper presented at CVPR 2017, a stream of research has applied deep learning directly to point cloud data, and DGCNN, the subject of this project, also builds on the PointNet architecture. DGCNN's key features can be summarized as follows.

  • It introduces a new operation called EdgeConv, which learns local features of the points while preserving permutation invariance (see the formula right after this list).
  • At every layer it dynamically rebuilds the graph relating the points, so the network can learn meaningful groupings. This is exactly why the architecture is named DGCNN (Dynamic Graph CNN).
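
For reference, the EdgeConv operation from the paper [1] can be written as

x_i' = \max_{j \in \mathcal{N}(i)} h_\Theta(x_i,\; x_j - x_i)

where \mathcal{N}(i) is the set of k nearest neighbors of point x_i and h_\Theta is a shared MLP. Concatenating the center point x_i with the offset x_j - x_i is exactly what the get_graph_feature function below constructs, and the max over neighbors is the symmetric operation that keeps everything permutation invariant.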

Since DGCNN is built on top of PointNet, it is worth looking at the PointNet architecture first.

PointNet

[Figure] PointNet model architectures

PointNet extracts features from the input points using only permutation-invariant operations: shared MLPs and max pooling. The T-Net used for the feature transform makes the learned point features invariant to certain transformations. For example, rotating the whole point cloud should change neither the global category of the cloud nor the per-point segmentation results, and PointNet addresses this with the affine transformation predicted by the T-Net. The T-Net itself, like the overall model, consists only of MLPs and max pooling. This T-Net idea is similar in spirit to Spatial Transformer Networks[3]. In the segmentation network, as shown in the figure above, the global feature is concatenated back in as input so that global context is also taken into account.
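
As a minimal sketch of this idea (my own simplification, not the paper's full network), the core of PointNet is a shared per-point MLP, implemented as 1x1 convolutions, followed by a symmetric max pooling that produces a global feature:

import torch
import torch.nn as nn

class TinyPointNet(nn.Module):
    """A stripped-down PointNet-style global feature extractor (sketch only)."""
    def __init__(self):
        super().__init__()
        # Shared MLP: the same weights are applied to every point (kernel_size=1).
        self.mlp = nn.Sequential(
            nn.Conv1d(3, 64, kernel_size=1), nn.ReLU(),
            nn.Conv1d(64, 1024, kernel_size=1), nn.ReLU())

    def forward(self, x):          # x: (batch_size, 3, num_points)
        x = self.mlp(x)            # (batch_size, 1024, num_points)
        return x.max(dim=-1)[0]    # max pooling over points -> (batch_size, 1024)

x = torch.randn(4, 3, 2048)
print(TinyPointNet()(x).shape)     # torch.Size([4, 1024])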

Because PointNet operates on the coordinates of each input point individually and then extracts a single global feature, it cannot capture local features from the neighborhood around a point. The authors later proposed PointNet++, an improved version that does consider local features, but it is not covered here.

DGCNN

[Figure] DGCNN model architecture

This is the DGCNN model architecture. It resembles the PointNet architecture, but EdgeConv is newly applied. Each time an EdgeConv is applied, the input graph is rebuilt from scratch using k-NN. So while the first EdgeConv picks the k points that are physically close in the input space, every subsequent EdgeConv picks a fresh set of k points that are close in feature space.


[Figure] Point cloud segmentation using the proposed neural network

As the figure above shows, if we mark a reference point (red) among the input points, the first layer selects the points with the smallest Euclidean distance in the input space. Repeating this process, the later layers can be seen selecting points that are close in feature space instead. In the code below, this regrouping of features is implemented in the get_graph_feature function.

The paper uses a total of 16,881 3D shapes, each sampled to 2,048 points. There are 16 object categories (airplane, chair, car, etc.), and each object's points are labeled with up to 6 parts (e.g., airplane: wing, body, tail, engine). For a quick experiment, the implementation below trains on only 2,048 of the roughly 17,000 shapes. The test set was likewise subsampled, which yielded 81.17% mIoU, lower than the paper's 85.2% mIoU since far less data was used.

I wrote this post because the company where I am currently interning works with 3D dental data. Joining a project that applies deep learning to segment dental scans into individual teeth naturally led me to study point clouds and the DGCNN architecture. Next, I plan to study and write about other model architectures that can be applied to point clouds (e.g., PointCNN, RandLA-Net).


Packages

Import the required packages.

import numpy as np 
import pandas as pd
from collections import Counter
import h5py
import json
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from easydict import EasyDict
from tqdm import tqdm

import plotly.offline as pyo
from plotly.offline import init_notebook_mode
import plotly.graph_objects as go
pyo.init_notebook_mode()

Arguments

Create a dictionary holding the hyperparameter values.

config = EasyDict({
    'TRAIN_DATA_PATH' : '../input/shapenetpart/shapenetpart_hdf5_2048/train0.h5',
    'VALID_DATA_PATH' : '../input/shapenetpart/shapenetpart_hdf5_2048/val0.h5',
    'TEST_DATA_PATH' : '../input/shapenetpart/shapenetpart_hdf5_2048/test0.h5',
    'DEVICE' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'BATCH_SIZE' : 16,
    'EPOCHS' : 200,
    'LEARNING_RATE' : 0.001,
    'DROP_OUT' : 0.5,
    'EMB_DIMS' : 1024,
    'SEG_NUM_ALL' : 50,
    'WEIGHT_DECAY' : 0.0001,
    'K' : 40
})
shapenetpart_cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, 
                       'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, 
                       'motorbike': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15}
shapenetpart_seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
shapenetpart_seg_start_index = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
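
The three mappings above assign each category a contiguous block of the 50 global part ids: seg_start_index gives the first id and seg_num the number of parts. A quick check (for illustration):

# The global part ids owned by 'chair' (category id 4): [12, 13, 14, 15]
cid = shapenetpart_cat2id['chair']
start = shapenetpart_seg_start_index[cid]
print(list(range(start, start + shapenetpart_seg_num[cid])))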

Loading the Data

Load the required data.

with open('../input/shapenetpart/shapenetpart_hdf5_2048/train0_id2file.json') as json_file:
    train_id = json.load(json_file)
with open('../input/shapenetpart/shapenetpart_hdf5_2048/val0_id2file.json') as json_file:
    valid_id = json.load(json_file)
    
with open('../input/shapenetpart/shapenetpart_hdf5_2048/train0_id2name.json') as json_file:
    train_name = json.load(json_file)
with open('../input/shapenetpart/shapenetpart_hdf5_2048/val0_id2name.json') as json_file:
    valid_name = json.load(json_file)
train_df = pd.DataFrame({'path' : train_id, 'label': train_name})
train_df['segmentation_part_num'] = train_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])

valid_df = pd.DataFrame({'path' : valid_id, 'label': valid_name})
valid_df['segmentation_part_num'] = valid_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])

print('#'*30)
print(f"Label distribution:\n{train_df['label'].value_counts()}\n")
print('#'*30)
print(f'Total data: {train_df.shape[0]}')
print('#'*30)
train_df.head()
##############################
Label distribution:
table         615
chair         454
airplane      335
lamp          198
car           120
guitar         88
laptop         57
knife          41
pistol         37
mug            27
skateboard     20
motorbike      19
bag            12
rocket         11
cap             9
earphone        5
Name: label, dtype: int64

##############################
Total data: 2048
##############################
path label segmentation_part_num
0 02691156/points/d4d61a35e8b568fb7f1f82f6fc8747... airplane 4
1 03636649/points/eee7062babab62aa8930422448288e... lamp 4
2 04379243/points/90992c45f7b2ee7d71a48b5339c6e0... table 3
3 02691156/points/a3c928995562fca8ca8607f540cc62... airplane 4
4 03636649/points/85335cc8e6ac212a3834555ce6c51f... lamp 4
def load_data(h5_path):
    f = h5py.File(h5_path, 'r')  # open read-only (avoids h5py's default-file-mode warning)
    data = f['data'][:].astype('float32')
    label = f['label'][:].astype('int64')
    seg = f['seg'][:].astype('int64')
    f.close()
    return data, label, seg

train_data, train_label, train_seg = load_data(config['TRAIN_DATA_PATH'])
valid_data, valid_label, valid_seg = load_data(config['VALID_DATA_PATH'])
# For quick experiment, get 20% data of the total valid set. (Total valid set has 1870 data.)
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for i, (_, valid_index) in enumerate(folds.split(valid_label,valid_label.reshape(-1,))):
    valid_part_data = valid_data[valid_index]
    valid_part_label = valid_label[valid_index]
    valid_part_seg = valid_seg[valid_index]
    valid_part_df = valid_df.iloc[valid_index].reset_index(drop=True)
    break
print(f'Total valid data: {valid_data.shape[0]}')
print(f'Valid data to use: {valid_part_data.shape[0]}')
Total valid data: 1870
Valid data to use: 374

Augmentations

Augmentation code that can be applied to point clouds.

# Randomly scale and shift the whole point cloud
def translate_pointcloud(pointcloud):
    xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
    xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
       
    translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
    return translated_pointcloud


# Jitter each point's x, y, z coordinates within the range -0.02 to 0.02
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    N, C = pointcloud.shape
    pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip)
    return pointcloud


# Rotate the whole point cloud (about the y-axis, i.e., in the x-z plane)
def rotate_pointcloud(pointcloud):
    theta = np.pi*2 * np.random.rand()
    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
    pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) 
    return pointcloud
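
A quick usage sketch (my own, for illustration): each augmentation takes and returns a (num_points, 3) array, so they can be chained freely, as the dataset class below does inside __getitem__:

# Apply all three augmentations to one training sample.
sample = train_data[0].copy()   # copy so the original array stays intact
sample = translate_pointcloud(sample)
sample = jitter_pointcloud(sample)
sample = rotate_pointcloud(sample)
print(sample.shape)             # (2048, 3)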

Visualization

Visualize the data to check it.

def visualize(points, label=None, name=None, seg_num=None, fname=None):
    x, y, z = np.array(points).T
    if label is not None:
        label = np.array(label).T
    layout = go.Layout(
        scene=dict(
            aspectmode='data'

        ))
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                       mode='markers',
                                       marker=dict(
                                           size=15,              
                                           color=label,          
                                           colorscale='rainbow',
                                           opacity=1.0,
                                       ))],
                    layout_title_text=f"[ShapeNet Dataset]   Label: {name},   Segmentation parts: {seg_num},   Total Points: {label.shape[0] if label is not None else None}",
                    layout=layout)
    fig.update_traces(marker=dict(size=1.0,
                                  line=dict(width=1.0,
                                            color='DarkSlateGrey')),
                      selector=dict(mode='markers'))
    fig.show()
    fig.write_html(f'{fname}.html')

visualize_index = 0
visualize(train_data[visualize_index],
          train_seg[visualize_index],
          train_df.loc[visualize_index, 'label'],
          train_df.loc[visualize_index, 'segmentation_part_num'],
          'airplane')

[Figure] Interactive 3D visualization of the airplane sample


Data Loader

The data loader code.

class ShapeNetDataset(Dataset):
    def __init__(self, dataframe, data, label, seg, random_translate=False, random_jitter=False, random_rotate=False):
        super().__init__()
        self.df = dataframe
        self.data = data
        self.label = label
        self.seg = seg
        self.class_name = dataframe['label']
#         self.seg_num_all = dataframe['segmentation_part_num']
        self.random_translate = random_translate
        self.random_jitter = random_jitter
        self.random_rotate = random_rotate
        self.seg_num_all = 50
        self.seg_start_index = 0

    def __getitem__(self, index):
        points = self.data[index]
        label = self.label[index]
        seg = self.seg[index]
        c_name = self.class_name[index]
#         seg_num_all = self.seg_num_all[index]
        if self.random_translate:
            points = translate_pointcloud(points)
        if self.random_jitter:
            points = jitter_pointcloud(points)
        if self.random_rotate:
            points = rotate_pointcloud(points)
        
        # categorical vector
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        
        
        points = torch.from_numpy(points)
        label = torch.from_numpy(label)
        seg = torch.from_numpy(seg)
        
        return points, label, label_one_hot.squeeze(1), seg, c_name#, seg_num_all 
    
    def __len__(self):
        return self.df.shape[0]

DGCNN Model

Code implementing the DGCNN model architecture. The implementation was written with reference to [7].

def knn(x, k):
    # Uses -||x_i - x_j||^2 = 2*x_i.x_j - ||x_i||^2 - ||x_j||^2, computed batch-wise.
    inner = -2*torch.matmul(x.transpose(2, 1), x) # (batch_size, num_points, num_dims) x (batch_size, num_dims, num_points) -> (batch_size, num_points, num_points)
    xx = torch.sum(x**2, dim=1, keepdim=True) # (batch_size, 1, num_points)
    pairwise_distance = -xx - inner - xx.transpose(2, 1) # negative squared distances: (batch_size, num_points, num_points)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k) / indices of the k nearest neighbors
    return idx
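
Since topk is taken over the negative squared distances, the largest values correspond to the k nearest neighbors. A quick sanity check against a brute-force torch.cdist computation (my own addition, not part of the original post):

x = torch.randn(2, 3, 32)                                  # (batch_size, num_dims, num_points)
idx_fast = knn(x, k=5)
dist = torch.cdist(x.transpose(2, 1), x.transpose(2, 1))   # brute-force pairwise distances
idx_ref = dist.topk(k=5, dim=-1, largest=False)[1]
print(torch.equal(idx_fast, idx_ref))                      # True (ties are negligible for random floats)
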
def get_graph_feature(x, k=20, idx=None, dim9=False, device='cpu'):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points) # (batch_size, 3, num_points)
    if idx is None:
        if dim9 == False:
            idx = knn(x, k=k)   # (batch_size, num_points, k)
        else:
            idx = knn(x[:, 6:], k=k)
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points # (batch_size, 1, 1)
    idx = idx + idx_base # update index numbers
    idx = idx.view(-1) # (batch_size * num_points * k)
    _, num_dims, _ = x.size() # num_dims = 3
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size*num_points, -1)[idx, :] # gather each point's k neighbors: (batch_size * num_points * k, num_dims)
    feature = feature.view(batch_size, num_points, k, num_dims) # (batch_size, num_points, k, 3)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)
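
A quick shape check (illustration only): for raw xyz input, each point gets paired with its k neighbors, and the channel dimension doubles because (x_j - x_i) is concatenated with x_i:

x = torch.randn(2, 3, 1024)           # (batch_size, 3, num_points)
feature = get_graph_feature(x, k=20)
print(feature.shape)                  # torch.Size([2, 6, 1024, 20])
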
class Transform_Net(nn.Module):
    def __init__(self, config):
        super(Transform_Net, self).__init__()
        self.config = config
        self.k = 3
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn3 = nn.BatchNorm1d(512)  # note: rebinds self.bn3; conv3 above keeps the BatchNorm1d(1024) instance it captured
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)
        self.transform = nn.Linear(256, 3*3)
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 128, num_points, k)
        x = x.max(dim=-1, keepdim=False)[0]     # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)
        x = self.conv3(x)                       # (batch_size, 128, num_points) -> (batch_size, 1024, num_points)
        x = x.max(dim=-1, keepdim=False)[0]     # (batch_size, 1024, num_points) -> (batch_size, 1024)
        x = F.leaky_relu(self.bn3(self.linear1(x)), negative_slope=0.2)     # (batch_size, 1024) -> (batch_size, 512)
        x = F.leaky_relu(self.bn4(self.linear2(x)), negative_slope=0.2)     # (batch_size, 512) -> (batch_size, 256)
        x = self.transform(x)                   # (batch_size, 256) -> (batch_size, 3*3)
        x = x.view(batch_size, 3, 3)            # (batch_size, 3*3) -> (batch_size, 3, 3)

        return x
class DGCNN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.seg_num_all = config['SEG_NUM_ALL']
        self.k = config['K']
        self.transform_net = Transform_Net(config)
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(config['EMB_DIMS'])
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, config['EMB_DIMS'], kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv8 = nn.Sequential(nn.Conv1d(1280, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=config['DROP_OUT'])
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=config['DROP_OUT'])
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                   self.bn10,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)
        

    def forward(self, x, l):
        batch_size = x.size(0)
        num_points = x.size(2)
        device = self.config['DEVICE']
        x0 = get_graph_feature(x, k=self.k, device=device)     # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = self.transform_net(x0)              # (batch_size, 3, 3)
        x = x.transpose(2, 1)                   # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)                     # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)                   # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        x = get_graph_feature(x, k=self.k, device=device)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k, device=device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k, device=device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv5(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)      # (batch_size, 64*3, num_points)

        x = self.conv6(x)                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        l = l.view(batch_size, -1, 1)           # (batch_size, num_categories, 1)
        l = self.conv7(l)                       # (batch_size, num_categories, 1) -> (batch_size, 64, 1)

        x = torch.cat((x, l), dim=1)            # (batch_size, 1088, 1)
        x = x.repeat(1, 1, num_points)          # (batch_size, 1088, num_points)

        x = torch.cat((x, x1, x2, x3), dim=1)   # (batch_size, 1088+64*3, num_points)

        x = self.conv8(x)                       # (batch_size, 1088+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)
        
        return x

Model Evaluation Metric - IoU

IoU (Intersection over Union) is used to check the model's performance: for each part, IoU = |intersection| / |union| of the predicted and ground-truth point sets, these per-part values are averaged over a shape's parts, and the mean over shapes gives the mIoU.

def calculate_shape_IoU(pred_np, seg_np, label):
    
    label = label.squeeze()
    shape_ious = []
    
    for shape_idx in range(seg_np.shape[0]): 
        # Each class uses a different range of global part indices, so restrict the computation to this class's range.
        start_index = shapenetpart_seg_start_index[label[shape_idx]]
        num = shapenetpart_seg_num[label[shape_idx]]
        parts = range(start_index, start_index + num)
        part_ious = []
        
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            if U == 0:
                iou = 1  # If the union of groundtruth and prediction points is empty, then count part IoU as 1
            else:
                iou = I / float(U)
            part_ious.append(iou)
        shape_ious.append(np.mean(part_ious))
    return shape_ious
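
A toy check of the metric (my own example): two 'airplane' shapes (label 0, so global part ids 0-3) with four points each. The first shape gets two parts fully right and misses the other two; the second is perfect (parts absent from both prediction and ground truth count as IoU 1):

pred = np.array([[0, 0, 1, 3], [0, 1, 1, 2]])
true = np.array([[0, 0, 1, 2], [0, 1, 1, 2]])
print(calculate_shape_IoU(pred, true, np.array([0, 0])))  # [0.5, 1.0]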

Train

The full training code.

def train(config):
    
    train_dataset = ShapeNetDataset(train_df, train_data, train_label, train_seg)
    valid_dataset = ShapeNetDataset(valid_part_df, valid_part_data, valid_part_label, valid_part_seg)

    train_loader  = DataLoader(train_dataset, batch_size=config['BATCH_SIZE'], shuffle=True, drop_last=True, num_workers=4)
    valid_loader  = DataLoader(valid_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, drop_last=True, num_workers=4)
    
    
    model = DGCNN(config).to(config['DEVICE'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['LEARNING_RATE'], weight_decay=config['WEIGHT_DECAY'])
    criterion = nn.CrossEntropyLoss()
    
    device = config['DEVICE']
    best_valid_iou = 0.0

    print(f'{"#"*30} Start Training {"#"*30}')
    for epoch in range(config['EPOCHS']):
        ############
        ## TRAIN  ##
        ############
        running_loss = 0.0
        model.train()
        total_step = len(train_loader)

        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []

        for i, (point, label, label_one_hot, seg, c_name) in enumerate(train_loader):

            point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
            point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

            seg_prediction = model(point, label_one_hot) 
            seg_prediction = seg_prediction.permute(0, 2, 1).contiguous() # (batch_size, seg_all_num, num_points) -> (batch_size, num_points, seg_all_num)

            loss = criterion(seg_prediction.view(-1, config['SEG_NUM_ALL']),  # (batch_size, num_points, 50) -> (batch_size x num_points, 50)
                             seg.view(-1,1).squeeze())        # (batch_size, num_points) -> (batch_size x num_points)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            pred = seg_prediction.max(dim=2)[1] # (batch_size, num_points) - index of the max over the seg_num_all (=50) dimension
            seg_np = seg.cpu().numpy()  # (batch_size, num_points)
            pred_np = pred.detach().cpu().numpy() # (batch_size, num_points)

            # For Accuracy computation
            train_true_cls.append(seg_np.reshape(-1)) # (batch_size * num_points, )
            train_pred_cls.append(pred_np.reshape(-1)) # (batch_size * num_points, )

            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            train_label_seg.append(label.reshape(-1))


        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        train_accuracy = sum(train_pred_cls == train_true_cls) / len(train_pred_cls) # Accuracy

        train_true_seg = np.concatenate(train_true_seg, axis=0)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        train_label_seg = np.concatenate(train_label_seg)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg, train_label_seg) # IoU
        torch.cuda.empty_cache()

        ############
        ## VALID  ##
        ############    
        valid_loss = 0.0
        model.eval()
        valid_true_cls = []
        valid_pred_cls = []
        valid_true_seg = []
        valid_pred_seg = []
        valid_label_seg = []

        with torch.no_grad():
            for i, (point, label, label_one_hot, seg, c_name) in enumerate(valid_loader):
                point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
                point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
                seg_prediction = model(point, label_one_hot) 
                seg_prediction = seg_prediction.permute(0, 2, 1).contiguous()
                loss = criterion(seg_prediction.view(-1, config['SEG_NUM_ALL']),  
                                         seg.view(-1,1).squeeze())   
                pred = seg_prediction.max(dim=2)[1]
                valid_loss += loss.item()
                seg_np = seg.cpu().numpy()
                pred_np = pred.detach().cpu().numpy()

                valid_true_cls.append(seg_np.reshape(-1))
                valid_pred_cls.append(pred_np.reshape(-1))

                valid_true_seg.append(seg_np)
                valid_pred_seg.append(pred_np)
                valid_label_seg.append(label.reshape(-1))


            valid_true_cls = np.concatenate(valid_true_cls)
            valid_pred_cls = np.concatenate(valid_pred_cls)
            valid_accuracy = sum(valid_pred_cls == valid_true_cls) / len(valid_pred_cls) # Accuracy

            valid_true_seg = np.concatenate(valid_true_seg, axis=0)
            valid_pred_seg = np.concatenate(valid_pred_seg, axis=0)
            valid_label_seg = np.concatenate(valid_label_seg)
            valid_ious = calculate_shape_IoU(valid_pred_seg, valid_true_seg, valid_label_seg) # IoU



            print("Epoch: {}/{}.. ".format(epoch + 1, config['EPOCHS']) +
                                  "Loss: {:.5f}.. ".format(running_loss / total_step) +
                                  "IoU: {:.5f}.. ".format(np.mean(train_ious)) + 
                                  "Accuracy: {:.5f}.. ".format(train_accuracy) + 
                                  "Valid Loss: {:.5f}.. ".format(valid_loss / len(valid_loader)) +
                                  "Valid Accuracy: {:.5f}.. ".format(valid_accuracy) +
                                  "Valid IoU: {:.5f}..".format(np.mean(valid_ious)))

        # Save the best checkpoint so far (based on validation mIoU)
        if np.mean(valid_ious) >= best_valid_iou:
            best_valid_iou = np.mean(valid_ious)
            torch.save(model.state_dict(), f'epoch{str(epoch + 1).zfill(3)}_seg.tar')  
        torch.cuda.empty_cache()
train(config)
    Epoch: 1/200.. Loss: 1.31692.. IoU: 0.54590.. Accuracy: 0.70046.. Valid Loss: 0.65172.. Valid Accuracy: 0.81989.. Valid IoU: 0.66308..
    Epoch: 2/200.. Loss: 0.55490.. IoU: 0.68076.. Accuracy: 0.84702.. Valid Loss: 0.39328.. Valid Accuracy: 0.87850.. Valid IoU: 0.73454..
    Epoch: 3/200.. Loss: 0.38516.. IoU: 0.72497.. Accuracy: 0.88587.. Valid Loss: 0.26998.. Valid Accuracy: 0.90768.. Valid IoU: 0.78491..
    Epoch: 4/200.. Loss: 0.31265.. IoU: 0.74951.. Accuracy: 0.90163.. Valid Loss: 0.26334.. Valid Accuracy: 0.90714.. Valid IoU: 0.78399..
    Epoch: 5/200.. Loss: 0.27762.. IoU: 0.76045.. Accuracy: 0.90901.. Valid Loss: 0.25485.. Valid Accuracy: 0.90983.. Valid IoU: 0.80047..
    Epoch: 6/200.. Loss: 0.25684.. IoU: 0.77536.. Accuracy: 0.91510.. Valid Loss: 0.22206.. Valid Accuracy: 0.91785.. Valid IoU: 0.81682..
    Epoch: 7/200.. Loss: 0.24241.. IoU: 0.78215.. Accuracy: 0.91883.. Valid Loss: 0.22032.. Valid Accuracy: 0.92526.. Valid IoU: 0.81722..
    Epoch: 8/200.. Loss: 0.23585.. IoU: 0.78331.. Accuracy: 0.91991.. Valid Loss: 0.21144.. Valid Accuracy: 0.92591.. Valid IoU: 0.82841..
    Epoch: 9/200.. Loss: 0.22112.. IoU: 0.79032.. Accuracy: 0.92524.. Valid Loss: 0.23086.. Valid Accuracy: 0.92256.. Valid IoU: 0.82251..
    Epoch: 10/200.. Loss: 0.21806.. IoU: 0.79583.. Accuracy: 0.92532.. Valid Loss: 0.21657.. Valid Accuracy: 0.92456.. Valid IoU: 0.82978..
    Epoch: 11/200.. Loss: 0.21182.. IoU: 0.79948.. Accuracy: 0.92750.. Valid Loss: 0.24648.. Valid Accuracy: 0.91635.. Valid IoU: 0.80556..
    Epoch: 12/200.. Loss: 0.21545.. IoU: 0.79723.. Accuracy: 0.92640.. Valid Loss: 0.22097.. Valid Accuracy: 0.92159.. Valid IoU: 0.81647..
    Epoch: 13/200.. Loss: 0.20051.. IoU: 0.80254.. Accuracy: 0.93094.. Valid Loss: 0.23823.. Valid Accuracy: 0.91569.. Valid IoU: 0.81353..
    Epoch: 14/200.. Loss: 0.19512.. IoU: 0.80915.. Accuracy: 0.93247.. Valid Loss: 0.21295.. Valid Accuracy: 0.92765.. Valid IoU: 0.82732..
    Epoch: 15/200.. Loss: 0.19006.. IoU: 0.81226.. Accuracy: 0.93412.. Valid Loss: 0.19994.. Valid Accuracy: 0.92901.. Valid IoU: 0.83136..
    Epoch: 16/200.. Loss: 0.18558.. IoU: 0.81408.. Accuracy: 0.93566.. Valid Loss: 0.21710.. Valid Accuracy: 0.92464.. Valid IoU: 0.82064..
    Epoch: 17/200.. Loss: 0.19991.. IoU: 0.80636.. Accuracy: 0.93024.. Valid Loss: 0.21422.. Valid Accuracy: 0.92818.. Valid IoU: 0.83190..
    Epoch: 18/200.. Loss: 0.18842.. IoU: 0.81347.. Accuracy: 0.93385.. Valid Loss: 0.22044.. Valid Accuracy: 0.91553.. Valid IoU: 0.82048..
    Epoch: 19/200.. Loss: 0.18017.. IoU: 0.81823.. Accuracy: 0.93626.. Valid Loss: 0.21378.. Valid Accuracy: 0.92428.. Valid IoU: 0.82102..
    Epoch: 20/200.. Loss: 0.19431.. IoU: 0.80714.. Accuracy: 0.93194.. Valid Loss: 0.21395.. Valid Accuracy: 0.92609.. Valid IoU: 0.83094..
    Epoch: 21/200.. Loss: 0.17439.. IoU: 0.81997.. Accuracy: 0.93900.. Valid Loss: 0.20616.. Valid Accuracy: 0.93017.. Valid IoU: 0.83651..
    Epoch: 22/200.. Loss: 0.17570.. IoU: 0.82240.. Accuracy: 0.93820.. Valid Loss: 0.24689.. Valid Accuracy: 0.91422.. Valid IoU: 0.81267..
    Epoch: 23/200.. Loss: 0.17458.. IoU: 0.82025.. Accuracy: 0.93858.. Valid Loss: 0.22000.. Valid Accuracy: 0.92761.. Valid IoU: 0.83319..
    Epoch: 24/200.. Loss: 0.16986.. IoU: 0.82343.. Accuracy: 0.94004.. Valid Loss: 0.19641.. Valid Accuracy: 0.93258.. Valid IoU: 0.84404..
    Epoch: 25/200.. Loss: 0.16861.. IoU: 0.82677.. Accuracy: 0.94052.. Valid Loss: 0.19602.. Valid Accuracy: 0.93314.. Valid IoU: 0.84136..
    Epoch: 26/200.. Loss: 0.17455.. IoU: 0.82170.. Accuracy: 0.93848.. Valid Loss: 0.18726.. Valid Accuracy: 0.93133.. Valid IoU: 0.83599..
    Epoch: 27/200.. Loss: 0.17093.. IoU: 0.82359.. Accuracy: 0.93934.. Valid Loss: 0.18854.. Valid Accuracy: 0.93474.. Valid IoU: 0.84933..
    Epoch: 28/200.. Loss: 0.16073.. IoU: 0.82810.. Accuracy: 0.94255.. Valid Loss: 0.18518.. Valid Accuracy: 0.93435.. Valid IoU: 0.84498..
    Epoch: 29/200.. Loss: 0.17026.. IoU: 0.82506.. Accuracy: 0.93964.. Valid Loss: 0.38316.. Valid Accuracy: 0.88538.. Valid IoU: 0.74567..
    Epoch: 30/200.. Loss: 0.17922.. IoU: 0.81712.. Accuracy: 0.93703.. Valid Loss: 0.22341.. Valid Accuracy: 0.91714.. Valid IoU: 0.81960..
    Epoch: 31/200.. Loss: 0.16245.. IoU: 0.82867.. Accuracy: 0.94206.. Valid Loss: 0.22318.. Valid Accuracy: 0.92237.. Valid IoU: 0.82672..
    Epoch: 32/200.. Loss: 0.17099.. IoU: 0.82517.. Accuracy: 0.93954.. Valid Loss: 0.22027.. Valid Accuracy: 0.92594.. Valid IoU: 0.82357..
    Epoch: 33/200.. Loss: 0.15605.. IoU: 0.83051.. Accuracy: 0.94452.. Valid Loss: 0.18681.. Valid Accuracy: 0.93503.. Valid IoU: 0.84846..
    Epoch: 34/200.. Loss: 0.15843.. IoU: 0.82996.. Accuracy: 0.94284.. Valid Loss: 0.20392.. Valid Accuracy: 0.92836.. Valid IoU: 0.84245..
    Epoch: 35/200.. Loss: 0.16202.. IoU: 0.82617.. Accuracy: 0.94193.. Valid Loss: 0.24778.. Valid Accuracy: 0.91084.. Valid IoU: 0.77830..
    Epoch: 36/200.. Loss: 0.15833.. IoU: 0.82924.. Accuracy: 0.94326.. Valid Loss: 0.21210.. Valid Accuracy: 0.93203.. Valid IoU: 0.84168..
    Epoch: 37/200.. Loss: 0.15254.. IoU: 0.83496.. Accuracy: 0.94545.. Valid Loss: 0.19363.. Valid Accuracy: 0.93366.. Valid IoU: 0.84353..
    Epoch: 38/200.. Loss: 0.15498.. IoU: 0.83194.. Accuracy: 0.94453.. Valid Loss: 0.21816.. Valid Accuracy: 0.92127.. Valid IoU: 0.80463..
    Epoch: 39/200.. Loss: 0.16115.. IoU: 0.82815.. Accuracy: 0.94203.. Valid Loss: 0.25720.. Valid Accuracy: 0.91437.. Valid IoU: 0.80738..
    Epoch: 40/200.. Loss: 0.16718.. IoU: 0.82467.. Accuracy: 0.94033.. Valid Loss: 0.18402.. Valid Accuracy: 0.93358.. Valid IoU: 0.84360..
    Epoch: 41/200.. Loss: 0.15862.. IoU: 0.83146.. Accuracy: 0.94304.. Valid Loss: 0.29855.. Valid Accuracy: 0.90643.. Valid IoU: 0.79258..
    Epoch: 42/200.. Loss: 0.15664.. IoU: 0.83043.. Accuracy: 0.94403.. Valid Loss: 0.20042.. Valid Accuracy: 0.93113.. Valid IoU: 0.84125..
    Epoch: 43/200.. Loss: 0.14966.. IoU: 0.83759.. Accuracy: 0.94630.. Valid Loss: 0.18705.. Valid Accuracy: 0.93755.. Valid IoU: 0.84825..
    Epoch: 44/200.. Loss: 0.14887.. IoU: 0.83623.. Accuracy: 0.94683.. Valid Loss: 0.19898.. Valid Accuracy: 0.93518.. Valid IoU: 0.84447..
    Epoch: 45/200.. Loss: 0.15301.. IoU: 0.83464.. Accuracy: 0.94513.. Valid Loss: 0.16970.. Valid Accuracy: 0.94148.. Valid IoU: 0.85462..
    Epoch: 46/200.. Loss: 0.14684.. IoU: 0.83714.. Accuracy: 0.94734.. Valid Loss: 0.27009.. Valid Accuracy: 0.90795.. Valid IoU: 0.80442..
    Epoch: 47/200.. Loss: 0.14504.. IoU: 0.83595.. Accuracy: 0.94767.. Valid Loss: 0.19581.. Valid Accuracy: 0.93256.. Valid IoU: 0.84187..
    Epoch: 48/200.. Loss: 0.15833.. IoU: 0.83007.. Accuracy: 0.94335.. Valid Loss: 0.26672.. Valid Accuracy: 0.92026.. Valid IoU: 0.82362..
    Epoch: 49/200.. Loss: 0.16092.. IoU: 0.82581.. Accuracy: 0.94222.. Valid Loss: 0.17721.. Valid Accuracy: 0.94083.. Valid IoU: 0.85759..
    Epoch: 50/200.. Loss: 0.14550.. IoU: 0.83773.. Accuracy: 0.94765.. Valid Loss: 0.20740.. Valid Accuracy: 0.92963.. Valid IoU: 0.83126..
    Epoch: 51/200.. Loss: 0.14189.. IoU: 0.83956.. Accuracy: 0.94905.. Valid Loss: 0.18474.. Valid Accuracy: 0.93543.. Valid IoU: 0.83999..
    Epoch: 52/200.. Loss: 0.15474.. IoU: 0.83124.. Accuracy: 0.94438.. Valid Loss: 0.47906.. Valid Accuracy: 0.82504.. Valid IoU: 0.71061..
    Epoch: 53/200.. Loss: 0.15270.. IoU: 0.83258.. Accuracy: 0.94569.. Valid Loss: 0.17861.. Valid Accuracy: 0.93641.. Valid IoU: 0.84815..
    Epoch: 54/200.. Loss: 0.14379.. IoU: 0.84142.. Accuracy: 0.94829.. Valid Loss: 0.19526.. Valid Accuracy: 0.93123.. Valid IoU: 0.83683..
    Epoch: 55/200.. Loss: 0.14429.. IoU: 0.84052.. Accuracy: 0.94802.. Valid Loss: 0.20402.. Valid Accuracy: 0.93151.. Valid IoU: 0.83997..
    Epoch: 56/200.. Loss: 0.14514.. IoU: 0.83779.. Accuracy: 0.94734.. Valid Loss: 0.23526.. Valid Accuracy: 0.92803.. Valid IoU: 0.83161..
    Epoch: 57/200.. Loss: 0.14962.. IoU: 0.83497.. Accuracy: 0.94604.. Valid Loss: 0.19088.. Valid Accuracy: 0.93505.. Valid IoU: 0.84819..
    Epoch: 58/200.. Loss: 0.14637.. IoU: 0.83939.. Accuracy: 0.94744.. Valid Loss: 0.21823.. Valid Accuracy: 0.92583.. Valid IoU: 0.83253..
    Epoch: 59/200.. Loss: 0.15409.. IoU: 0.83140.. Accuracy: 0.94492.. Valid Loss: 0.40655.. Valid Accuracy: 0.86304.. Valid IoU: 0.75745..
    Epoch: 60/200.. Loss: 0.14265.. IoU: 0.83913.. Accuracy: 0.94861.. Valid Loss: 0.19094.. Valid Accuracy: 0.93554.. Valid IoU: 0.84706..
    Epoch: 61/200.. Loss: 0.14286.. IoU: 0.84096.. Accuracy: 0.94832.. Valid Loss: 0.20205.. Valid Accuracy: 0.93043.. Valid IoU: 0.81473..
    Epoch: 62/200.. Loss: 0.14355.. IoU: 0.83862.. Accuracy: 0.94835.. Valid Loss: 0.20353.. Valid Accuracy: 0.93093.. Valid IoU: 0.83541..
    Epoch: 63/200.. Loss: 0.14098.. IoU: 0.84437.. Accuracy: 0.94908.. Valid Loss: 0.25240.. Valid Accuracy: 0.91294.. Valid IoU: 0.81494..
    Epoch: 64/200.. Loss: 0.14083.. IoU: 0.83992.. Accuracy: 0.94895.. Valid Loss: 0.20649.. Valid Accuracy: 0.93593.. Valid IoU: 0.84198..
    Epoch: 65/200.. Loss: 0.15158.. IoU: 0.83413.. Accuracy: 0.94502.. Valid Loss: 0.20826.. Valid Accuracy: 0.93343.. Valid IoU: 0.84039..
    Epoch: 66/200.. Loss: 0.16724.. IoU: 0.82294.. Accuracy: 0.94009.. Valid Loss: 0.23660.. Valid Accuracy: 0.92282.. Valid IoU: 0.81949..
    Epoch: 67/200.. Loss: 0.14525.. IoU: 0.83594.. Accuracy: 0.94782.. Valid Loss: 0.21190.. Valid Accuracy: 0.93242.. Valid IoU: 0.84279..
    Epoch: 68/200.. Loss: 0.13772.. IoU: 0.84340.. Accuracy: 0.95039.. Valid Loss: 0.18135.. Valid Accuracy: 0.93932.. Valid IoU: 0.85039..
    Epoch: 69/200.. Loss: 0.13618.. IoU: 0.84754.. Accuracy: 0.95095.. Valid Loss: 0.20335.. Valid Accuracy: 0.93110.. Valid IoU: 0.83127..
    Epoch: 70/200.. Loss: 0.13659.. IoU: 0.84179.. Accuracy: 0.95069.. Valid Loss: 0.18244.. Valid Accuracy: 0.93982.. Valid IoU: 0.85437..
    Epoch: 71/200.. Loss: 0.14031.. IoU: 0.83734.. Accuracy: 0.94919.. Valid Loss: 0.21056.. Valid Accuracy: 0.92931.. Valid IoU: 0.84468..
    Epoch: 72/200.. Loss: 0.13603.. IoU: 0.84462.. Accuracy: 0.95114.. Valid Loss: 0.24763.. Valid Accuracy: 0.92385.. Valid IoU: 0.82229..
    Epoch: 73/200.. Loss: 0.14204.. IoU: 0.84362.. Accuracy: 0.94875.. Valid Loss: 0.25850.. Valid Accuracy: 0.91415.. Valid IoU: 0.81627..
    Epoch: 74/200.. Loss: 0.14324.. IoU: 0.83919.. Accuracy: 0.94892.. Valid Loss: 0.24689.. Valid Accuracy: 0.92635.. Valid IoU: 0.83444..
    Epoch: 75/200.. Loss: 0.14355.. IoU: 0.83870.. Accuracy: 0.94826.. Valid Loss: 0.18079.. Valid Accuracy: 0.93849.. Valid IoU: 0.85249..
    Epoch: 76/200.. Loss: 0.13630.. IoU: 0.84431.. Accuracy: 0.95086.. Valid Loss: 0.25681.. Valid Accuracy: 0.91564.. Valid IoU: 0.80750..
    Epoch: 77/200.. Loss: 0.13581.. IoU: 0.84146.. Accuracy: 0.95097.. Valid Loss: 0.19432.. Valid Accuracy: 0.93582.. Valid IoU: 0.84452..
    Epoch: 78/200.. Loss: 0.13732.. IoU: 0.84624.. Accuracy: 0.95060.. Valid Loss: 0.21660.. Valid Accuracy: 0.92575.. Valid IoU: 0.83615..
    Epoch: 79/200.. Loss: 0.14015.. IoU: 0.84498.. Accuracy: 0.94937.. Valid Loss: 0.24544.. Valid Accuracy: 0.92610.. Valid IoU: 0.82973..
    Epoch: 80/200.. Loss: 0.16115.. IoU: 0.82898.. Accuracy: 0.94340.. Valid Loss: 0.17465.. Valid Accuracy: 0.93654.. Valid IoU: 0.84511..
    Epoch: 81/200.. Loss: 0.15004.. IoU: 0.83131.. Accuracy: 0.94588.. Valid Loss: 0.19234.. Valid Accuracy: 0.93430.. Valid IoU: 0.83985..
    Epoch: 82/200.. Loss: 0.14164.. IoU: 0.84265.. Accuracy: 0.94882.. Valid Loss: 0.18463.. Valid Accuracy: 0.93401.. Valid IoU: 0.84632..
    Epoch: 83/200.. Loss: 0.13916.. IoU: 0.84483.. Accuracy: 0.95011.. Valid Loss: 0.19545.. Valid Accuracy: 0.93725.. Valid IoU: 0.84867..
    Epoch: 84/200.. Loss: 0.14014.. IoU: 0.84242.. Accuracy: 0.94960.. Valid Loss: 0.18799.. Valid Accuracy: 0.93462.. Valid IoU: 0.84691..
    Epoch: 85/200.. Loss: 0.13429.. IoU: 0.84732.. Accuracy: 0.95123.. Valid Loss: 0.24446.. Valid Accuracy: 0.91742.. Valid IoU: 0.80324..

Test

Test code that makes predictions on data not used during training.

def test(config):
    with open('../input/shapenetpart/shapenetpart_hdf5_2048/test0_id2file.json') as json_file:
        test_id = json.load(json_file)
    with open('../input/shapenetpart/shapenetpart_hdf5_2048/test0_id2name.json') as json_file:
        test_name = json.load(json_file)
        
    test_df = pd.DataFrame({'path' : test_id, 'label': test_name})
    test_df['segmentation_part_num'] = test_df['label'].apply(lambda x : shapenetpart_seg_num[shapenetpart_cat2id[x]])
    test_data, test_label, test_seg = load_data(config['TEST_DATA_PATH'])
    
    test_dataset = ShapeNetDataset(test_df, test_data, test_label, test_seg)
    test_loader  = DataLoader(test_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, drop_last=False, num_workers=4)

    model = DGCNN(config).to(config['DEVICE'])
    model.load_state_dict(torch.load('../input/shapenet-pretrained/epoch049_seg.tar', map_location=config['DEVICE']))
    model.eval()
    
    device = config['DEVICE']

    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    test_point = []
    with torch.no_grad():
        for i, (point, label, label_one_hot, seg, c_name) in tqdm(enumerate(test_loader), total=len(test_loader)):
            point, label_one_hot, seg = point.to(device), label_one_hot.to(device), seg.to(device)
            point = point.permute(0, 2, 1) # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
            seg_prediction = model(point, label_one_hot) 
            seg_prediction = seg_prediction.permute(0, 2, 1).contiguous()

            pred = seg_prediction.max(dim=2)[1]
            seg_np = seg.cpu().numpy()
            pred_np = pred.detach().cpu().numpy()

            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))

            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1))

            test_point.append(point.permute(0,2,1).detach().cpu().numpy())

        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_accuracy = sum(test_pred_cls == test_true_cls) / len(test_pred_cls) # Accuracy

        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg) # IoU

        test_point = np.concatenate(test_point)
    print('Inference fin... ')
    print("Test IoU: {:.5f}.. ".format(np.mean(test_ious)) + 
          "Test Accuracy: {:.5f}.. ".format(test_accuracy))   
    
    return test_true_seg, test_pred_seg, test_point,  test_df, test_ious
test_true_seg, test_pred_seg, test_point, test_df, test_ious = test(config)

100%|██████████| 128/128 [00:29<00:00,  4.34it/s]


Inference fin... 
Test IoU: 0.81172.. Test Accuracy: 0.92177.. 

Prediction result

The model's prediction result.

idx = 2045
visualize(test_point[idx], 
          test_pred_seg[idx],
          test_df.loc[idx, 'label'],
          test_df.loc[idx, 'segmentation_part_num'],
          'chair_pred')
print(f'Test IoU: {test_ious[idx]}')
Test IoU: 0.9765684201114649

[Figure] Predicted segmentation of the chair sample


True label

The ground-truth label.

visualize(test_point[idx], 
          test_true_seg[idx],
          test_df.loc[idx, 'label'],
          test_df.loc[idx, 'segmentation_part_num'],
          'chair_true')
print(f'Test IoU: {test_ious[idx]}')
Test IoU: 0.9765684201114649

[Figure] Ground-truth segmentation of the chair sample


References

[1] DGCNN paper: https://arxiv.org/abs/1801.07829v2
[2] PointNet paper: https://arxiv.org/abs/1612.00593
[3] Spatial Transformer Networks paper: https://arxiv.org/abs/1506.02025, NIPS 2015
[4] DGCNN explained (video): https://www.youtube.com/watch?v=Rv3osRZWGbg
[5] PointNet explained: https://jhyeup.tistory.com/entry/PointNet
[6] Spatial Transformer Networks explained (video): https://www.youtube.com/watch?v=Rv3osRZWGbg
[7] Model code: https://github.com/AnTao97/dgcnn.pytorch/blob/master/model.py
[8] Dataset: https://github.com/AnTao97/PointCloudDatasets
[9] Data loading code: https://github.com/AnTao97/PointCloudDatasets/blob/master/dataset.py
