PointNet Code Analysis


Author: 陆语


## train.py (training)

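train_one_epoch function

  • The shuffle function randomizes the order of the input data
  • The jitter and rotate functions from provider.py randomly perturb each batch (data augmentation)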
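
As a rough NumPy sketch of these steps (not the repository's code): `shuffle_data_sketch` and `iterate_batches` are hypothetical names, and dropping an incomplete final batch is an assumption of this sketch. Samples and labels must be shuffled with the same permutation so that each shape keeps its label.

```python
import numpy as np

def shuffle_data_sketch(data, labels, rng):
    """Hypothetical helper: reorder samples and labels with one shared random permutation."""
    idx = rng.permutation(len(labels))
    return data[idx, ...], labels[idx], idx

def iterate_batches(data, labels, batch_size):
    """Hypothetical helper: yield full mini-batches in order; a trailing partial batch is dropped."""
    num_batches = len(labels) // batch_size
    for b in range(num_batches):
        start, end = b * batch_size, (b + 1) * batch_size
        # This is where train_one_epoch would apply the random rotation and jitter
        # from provider.py to the batch before feeding it to the network.
        yield data[start:end], labels[start:end]

rng = np.random.default_rng(0)
data = np.arange(5 * 4 * 3, dtype=np.float32).reshape(5, 4, 3)  # 5 toy shapes with 4 points each
labels = np.arange(5)                                             # sample i has label i
shuffled_data, shuffled_labels, idx = shuffle_data_sketch(data, labels, rng)
batches = list(iterate_batches(shuffled_data, shuffled_labels, batch_size=2))
```

After shuffling, `shuffled_data[i]` is still the sample whose label is `shuffled_labels[i]`; only the order in which the network sees them changes from epoch to epoch.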

eval_one_epoch function

  • Computes the accuracy and loss on the evaluation data
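
A minimal sketch of how accuracy and loss can be accumulated across evaluation batches, assuming the network has already produced logits and a mean loss for each batch; `eval_accumulate` is a hypothetical helper, not the function in train.py.

```python
import numpy as np

def eval_accumulate(batches):
    """Hypothetical helper: accumulate accuracy and mean loss over one evaluation epoch.

    Each element of `batches` is (logits, labels, batch_loss): logits has shape
    (B, num_classes), labels has shape (B,), batch_loss is that batch's mean loss.
    """
    total_correct, total_seen = 0, 0
    loss_sum, num_batches = 0.0, 0
    for logits, labels, batch_loss in batches:
        pred = np.argmax(logits, axis=1)  # predicted class per sample
        total_correct += int(np.sum(pred == labels))
        total_seen += len(labels)
        loss_sum += batch_loss
        num_batches += 1
    return total_correct / total_seen, loss_sum / num_batches

# Two toy batches with 3 classes: 3 of the 4 predictions are correct.
batches = [
    (np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]]), np.array([0, 1]), 0.4),
    (np.array([[0.9, 0.2, 0.1], [0.1, 0.3, 2.2]]), np.array([0, 1]), 0.8),
]
acc, mean_loss = eval_accumulate(batches)
```

Counting correct predictions over all samples (rather than averaging per-batch accuracies) keeps the result exact even if batches differ in size.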

## provider.py (dataset loading, rotation, and jitter functions)

jitter_point_cloud function

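    import numpy as np

    def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
        # batch_data: B x N x C array (C = 3 for xyz)
        B, N, C = batch_data.shape
        assert(clip > 0)
        # Gaussian noise scaled by sigma, then clipped to [-clip, clip]
        jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
        jittered_data += batch_data
        return jittered_data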
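  • np.random.randn draws standard-normal noise of shape (B, N, C), i.e. for all three coordinates of every point, which is then scaled by sigma
  • np.clip limits the noise on each coordinate to [-clip, clip] (±0.05 by default); the clipped noise is then added to the original points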
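
A quick NumPy check of the clipping bound, reusing the function above unchanged: even with a deliberately large `sigma` (an arbitrary value chosen only for illustration), no coordinate moves by more than `clip`.

```python
import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Same logic as provider.py: add clipped Gaussian noise to every coordinate."""
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data

np.random.seed(0)
points = np.zeros((2, 1024, 3), dtype=np.float32)
# sigma=1.0 is far larger than the default, so most raw noise values exceed clip.
out = jitter_point_cloud(points, sigma=1.0, clip=0.05)
max_shift = np.abs(out - points).max()  # largest per-coordinate displacement
```

With the default `sigma=0.01`, the bound of 0.05 lies five standard deviations out, so clipping only trims rare outliers.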

rotate_point_cloud_by_angle function

    def rotate_point_cloud_by_angle(batch_data, rotation_angle):
        # Rotate every point cloud in the batch about the y axis by rotation_angle
        rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
        for k in range(batch_data.shape[0]):
            #rotation_angle = np.random.uniform() * 2 * np.pi
            cosval = np.cos(rotation_angle)
            sinval = np.sin(rotation_angle)
            rotation_matrix = np.array([[cosval, 0, sinval],
                                        [0, 1, 0],
                                        [-sinval, 0, cosval]])
            shape_pc = batch_data[k, ...]
            # Points are row vectors, so each point is multiplied on the left of the matrix
            rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
        return rotated_data
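  • Each point (a row vector) is multiplied by the rotation matrix, which rotates the shape about the y axis by rotation_angle; the commented-out line shows how a random angle in [0, 2π) could be drawn instead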
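
Since the same matrix is applied to every point, the per-sample loop can also be written as a single matrix product. A sketch, where `rotation_about_y` and `rotate_batch` are hypothetical helper names; the checks confirm that the y coordinate and each point's distance from the origin are unchanged.

```python
import numpy as np

def rotation_about_y(angle):
    """Hypothetical helper: the same matrix as in rotate_point_cloud_by_angle."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s],
                     [0.0, 1.0, 0.0],
                     [-s, 0.0, c]])

def rotate_batch(batch_data, angle):
    """Hypothetical vectorized equivalent of the loop: every row (point) times R."""
    flat = batch_data.reshape(-1, 3)  # stack all points of all samples
    return (flat @ rotation_about_y(angle)).reshape(batch_data.shape)

rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 128, 3))  # 4 random toy clouds of 128 points
rotated = rotate_batch(batch, np.pi / 3)
```

For example, `rotate_batch(np.array([[[1.0, 0.0, 0.0]]]), np.pi / 2)` maps the point (1, 0, 0) to approximately (0, 0, 1): the row vector picks out the first row of the matrix.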

## pointnet_cls.py (classification network)

get_model function

[Figure: PointNet architecture (Pointnet.png)]

    with tf.variable_scope('transform_net1') as sc:
        transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
    point_cloud_transformed = tf.matmul(point_cloud, transform)
    input_image = tf.expand_dims(point_cloud_transformed, -1)
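  • The input T-Net predicts a 3×3 transformation matrix from the points; multiplying the point cloud by it aligns the input. expand_dims then adds a channel axis, giving a B×N×3×1 tensor for the 2D convolutions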
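
The alignment step in NumPy terms: `np.matmul` on rank-3 arrays multiplies matching batch entries, which is how `tf.matmul` is used here. The transform below is a made-up stand-in for the T-Net output (identity plus small noise), and the sizes are arbitrary.

```python
import numpy as np

B, N = 2, 1024
rng = np.random.default_rng(0)
point_cloud = rng.standard_normal((B, N, 3)).astype(np.float32)

# Stand-in for the T-Net output: one 3x3 matrix per sample (identity plus small noise).
noise = 0.01 * rng.standard_normal((B, 3, 3)).astype(np.float32)
transform = np.eye(3, dtype=np.float32) + noise

# Batched product: (B, N, 3) @ (B, 3, 3) -> (B, N, 3), sample b uses transform[b].
point_cloud_transformed = np.matmul(point_cloud, transform)
# Trailing channel axis so the 2D convolutions see an N x 3 single-channel "image".
input_image = np.expand_dims(point_cloud_transformed, -1)
```

An identity transform leaves the cloud unchanged, which is exactly the state the T-Net is initialized to.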

    net = tf_util.conv2d(input_image, 64, [1,3],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv1', bn_decay=bn_decay)

    net = tf_util.conv2d(net, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv2', bn_decay=bn_decay)
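  • Two convolution layers, i.e. mlp(64,64) in the figure: the [1,3] kernel of conv1 spans the x, y, z coordinates of one point, and conv2 uses a [1,1] kernel, so the same weights are applied to every point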
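
Why a [1,3] kernel acts as a per-point layer: on a B×N×3×1 input with 'VALID' padding, the kernel covers exactly one point's xyz and slides only along the point axis. The sketch below is a naive convolution written for illustration, with random weights, omitting the batch normalization (bn=True above) and any activation function; it checks that the result matches one 3→64 matrix product per point.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, out_ch = 2, 16, 64
input_image = rng.standard_normal((B, N, 3, 1))
kernel = rng.standard_normal((1, 3, 1, out_ch))  # [height, width, in_channels, out_channels]
bias = rng.standard_normal(out_ch)

def conv_valid(x, k, b):
    """Illustrative 'VALID' 2D convolution with stride 1 (no BN, no activation)."""
    batch, H, W, _ = x.shape
    kh, kw, _, co = k.shape
    out = np.zeros((batch, H - kh + 1, W - kw + 1, co))
    for i in range(H - kh + 1):
        for j in range(W - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (B, 1, 3, 1): one point's xyz
            out[:, i, j, :] = np.tensordot(patch, k, axes=([1, 2, 3], [0, 1, 2])) + b
    return out

conv_out = conv_valid(input_image, kernel, bias)  # (B, N, 1, 64)
# The same numbers from a single dense layer applied to each point: xyz (3) -> 64.
dense_out = input_image[..., 0] @ kernel[0, :, 0, :] + bias  # (B, N, 64)
```

This is why the figure can label these layers as an MLP: the convolution is only a convenient way to share one dense layer across all N points.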

    with tf.variable_scope('transform_net2') as sc:
        transform = feature_transform_net(net, is_training, bn_decay, K=64)
    end_points['transform'] = transform
    net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
    net_transformed = tf.expand_dims(net_transformed, [2])
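  • The feature T-Net predicts a 64×64 transformation matrix; multiplying the per-point features by it aligns them in feature space. The matrix is also stored in end_points['transform']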
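
The squeeze / matmul / expand_dims sequence only reshapes around a batched 64×64 matrix product. A NumPy shape walk-through with random stand-in values (the sizes B and N are arbitrary):

```python
import numpy as np

B, N, K = 2, 32, 64
rng = np.random.default_rng(0)
net = rng.standard_normal((B, N, 1, K))     # stand-in for the conv2 output: (B, N, 1, 64)
transform = rng.standard_normal((B, K, K))  # stand-in for the feature T-Net output: (B, 64, 64)

features = np.squeeze(net, axis=2)                    # (B, N, 64): drop the width-1 axis
net_transformed = np.matmul(features, transform)      # per-sample 64x64 alignment
net_transformed = np.expand_dims(net_transformed, 2)  # back to (B, N, 1, 64) for conv3
```

Each point's 64-dimensional feature vector is multiplied by the matrix of its own sample, just as the xyz coordinates were in the input transform.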

    net = tf_util.conv2d(net_transformed, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv3', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv4', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='conv5', bn_decay=bn_decay)
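  • Three more convolution layers extract higher-level features, i.e. mlp(64,128,1024) in the figure, producing a 1024-dimensional feature for each point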
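
Because the same weights are applied to every point, this stack of layers is permutation-equivariant: reordering the input points only reorders the output rows. A small NumPy sketch of a shared MLP with the widths above (random weights, ReLU, batch normalization omitted; `shared_mlp` is a hypothetical helper, not the trained network):

```python
import numpy as np

rng = np.random.default_rng(0)
widths = [64, 64, 128, 1024]  # 64-d input (after the feature transform), then mlp(64, 128, 1024)
weights = [rng.standard_normal((a, b)) * 0.1 for a, b in zip(widths[:-1], widths[1:])]
biases = [np.zeros(b) for b in widths[1:]]

def shared_mlp(points, weights, biases):
    """Hypothetical helper: apply the same dense layers (with ReLU) to every point."""
    h = points
    for W, b in zip(weights, biases):
        h = np.maximum(h @ W + b, 0.0)
    return h

N = 50
feats = rng.standard_normal((N, 64))
out = shared_mlp(feats, weights, biases)  # (N, 1024): one 1024-d feature per point
perm = rng.permutation(N)
out_perm = shared_mlp(feats[perm], weights, biases)  # same rows, permuted order
```

Equivariance alone does not make the output order-free; the max pooling that follows removes the remaining dependence on point order.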

    net = tf_util.max_pool2d(net, [num_point,1],
                             padding='VALID', scope='maxpool')
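  • Max pooling over all num_point points serves as the symmetric function: it yields a 1024-dimensional global feature vector that is independent of point order, which addresses the unordered nature of point clouds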
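
A quick NumPy illustration of why max pooling is order-independent: the maximum over the point axis gives the same vector for any ordering of the points. The features are random stand-ins for the conv5 output of one cloud.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 1024
point_features = rng.standard_normal((N, 1024))  # stand-in: one 1024-d feature per point
global_feature = point_features.max(axis=0)      # max over the N points -> (1024,)
shuffled = point_features[rng.permutation(N)]    # same points, different order

# With the (B, N, 1, 1024) layout used above, a [num_point, 1] 'VALID' max pool
# corresponds to a max over axis 1, giving (B, 1, 1, 1024).
batched = point_features[None, :, None, :]
pooled = batched.max(axis=1, keepdims=True)
```

Any per-channel symmetric function (max, sum, mean) would be order-independent; max is the one used here.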

    net = tf.reshape(net, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
                          scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
                          scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')
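  • Fully connected layers mlp(512,256,k), with dropout (keep_prob=0.7) after the first two; here k = 40, the number of ModelNet40 classes. fc3 has no activation, so it outputs raw class scores (logits)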
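
A NumPy sketch of this head, assuming `tf_util.dropout` behaves like inverted dropout during training (as `tf.nn.dropout` does: keep each unit with probability `keep_prob` and scale kept units by `1/keep_prob`) and passes inputs through unchanged at test time. Weights are random, and biases and batch normalization are omitted; only the shapes and the dropout behavior are meant to be illustrative.

```python
import numpy as np

def dropout(x, keep_prob, is_training, rng):
    """Illustrative inverted dropout; identity when not training."""
    if not is_training:
        return x
    mask = rng.random(x.shape) < keep_prob  # keep each unit with probability keep_prob
    return x * mask / keep_prob             # rescale so the expected value is unchanged

rng = np.random.default_rng(0)
B, num_classes = 4, 40
global_feature = rng.standard_normal((B, 1024))  # stand-in for the max-pooled feature
dims = [1024, 512, 256]
h = global_feature
for a, b in zip(dims[:-1], dims[1:]):
    W = rng.standard_normal((a, b)) * 0.05
    h = np.maximum(h @ W, 0.0)                                # fc1 / fc2 with ReLU (BN omitted)
    h = dropout(h, keep_prob=0.7, is_training=True, rng=rng)  # dp1 / dp2
logits = h @ (rng.standard_normal((256, num_classes)) * 0.05)  # fc3: no activation, 40 logits
```

The `1/keep_prob` rescaling is what lets the same weights be used unchanged at test time, when dropout is switched off.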

get_loss function

  • The loss is the softmax cross-entropy between the predicted logits and the ground-truth labels
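
A NumPy sketch of the loss. The cross-entropy part follows the description above. The second term is an addition for illustration: the PointNet paper proposes a regularizer ||I - A A^T||_F^2 that keeps the 64×64 feature transform A close to orthogonal, and `reg_weight=0.001` here is an illustrative value, not taken from this post. All function names are hypothetical.

```python
import numpy as np

def classify_loss(logits, labels):
    """Mean softmax cross-entropy, computed in a numerically stable way."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def orthogonality_loss(transform):
    """Sum over the batch of ||A A^T - I||_F^2 for each K x K matrix A."""
    K = transform.shape[1]
    mat_diff = transform @ np.transpose(transform, (0, 2, 1)) - np.eye(K)
    return np.sum(mat_diff ** 2)

def total_loss(logits, labels, transform, reg_weight=0.001):
    """Illustrative combination: cross-entropy plus weighted orthogonality penalty."""
    return classify_loss(logits, labels) + reg_weight * orthogonality_loss(transform)
```

For example, all-zero logits over 40 classes give a cross-entropy of log(40), and any orthogonal transform (identity, permutation, rotation) adds no penalty, while 2·I adds 9 per diagonal entry.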

## transform_net.py (T-Net)

    def input_transform_net(point_cloud, is_training, bn_decay=None, K=3):
        """ Input (XYZ) Transform Net, input is BxNx3 gray image
            Return:
                Transformation matrix of size 3xK """
        batch_size = point_cloud.get_shape()[0].value
        num_point = point_cloud.get_shape()[1].value

        input_image = tf.expand_dims(point_cloud, -1)
        net = tf_util.conv2d(input_image, 64, [1,3],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='tconv1', bn_decay=bn_decay)
        net = tf_util.conv2d(net, 128, [1,1],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='tconv2', bn_decay=bn_decay)
        net = tf_util.conv2d(net, 1024, [1,1],
                             padding='VALID', stride=[1,1],
                             bn=True, is_training=is_training,
                             scope='tconv3', bn_decay=bn_decay)
        net = tf_util.max_pool2d(net, [num_point,1],
                                 padding='VALID', scope='tmaxpool')

        net = tf.reshape(net, [batch_size, -1])
        net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                      scope='tfc1', bn_decay=bn_decay)
        net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                      scope='tfc2', bn_decay=bn_decay)

        with tf.variable_scope('transform_XYZ') as sc:
            assert(K==3)
            weights = tf.get_variable('weights', [256, 3*K],
                                      initializer=tf.constant_initializer(0.0),
                                      dtype=tf.float32)
            biases = tf.get_variable('biases', [3*K],
                                     initializer=tf.constant_initializer(0.0),
                                     dtype=tf.float32)
            biases += tf.constant([1,0,0,0,1,0,0,0,1], dtype=tf.float32)
            transform = tf.matmul(net, weights)
            transform = tf.nn.bias_add(transform, biases)

        transform = tf.reshape(transform, [batch_size, 3, K])
        return transform
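  • The T-Net for the input outputs a 3×K (here 3×3) transformation matrix that is multiplied with the input points. The final layer's weights start at zero and its biases at the flattened identity matrix, so the initial transform is the identity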
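
The last layer explains why the T-Net starts out harmless: with all weights initialized to 0 and the bias shifted by the flattened 3×3 identity, the output is the identity matrix for any input features. A NumPy check with random stand-in features:

```python
import numpy as np

B, K = 3, 3
rng = np.random.default_rng(0)
net = rng.standard_normal((B, 256))  # stand-in for the tfc2 output

weights = np.zeros((256, 3 * K))     # constant_initializer(0.0)
biases = np.zeros(3 * K) + np.array([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=float)  # flattened identity

transform = (net @ weights + biases).reshape(B, 3, K)  # matmul, bias_add, reshape
```

Training therefore begins from "no transformation" and only moves away from the identity as the weights are updated.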

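Title: PointNet Code Analysis

Author: Cello

Published: October 21, 2018, 16:10

Last updated: October 21, 2018, 16:10

Original link: https://littlexy.git.io/2018/10/21/PointNet代码分析/

License: Attribution-NonCommercial-NoDerivatives 4.0 International (CC BY-NC-ND 4.0). When reposting, please keep the original link and author.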
