PointNet代码分析
作者:陆语
##train.py (训练相关)
train_one_epoch函数
- shuffle函数打乱输入数据的顺序
- 使用provider.py文件中的jitter函数和rotate函数对数据作随机处理
eval_one_epoch
- 计算acc和loss
##provider.py (获取数据集,旋转,扰动函数)
jitter_point_cloud函数
1 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): |
- 利用randn产生三个维度上的符合高斯分布的随机数
- 利用clip函数将三个维度的扰动限制在(sigma,-sigma)
rotate_point_cloud_by_angle函数
1 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): |
- 原向量乘上转换矩阵,得到旋转后的向量
##pointnet_cls.py(分类网络模型)
get_model函数
1 | with tf.variable_scope('transform_net1') as sc: |
- T-net处理输入获得转换矩阵,矩阵相乘对齐输入
1 | net = tf_util.conv2d(input_image, 64, [1,3], |
- 两层卷积,即图中mlp(64,64)
1 | with tf.variable_scope('transform_net2') as sc: |
- T-net处理特征得到转换矩阵,矩阵相乘对齐特征
1 | net = tf_util.conv2d(net_transformed, 64, [1,1], |
- 三层卷积继续提取特征,即mlp(64,128,1024)
1 | net = tf_util.max_pool2d(net, [num_point,1], |
池化层作为对称函数,得到1024为的特征向量,解决点云的无序性
1
2
3
4
5
6
7
8
9
10net = tf.reshape(net, [batch_size, -1])
net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
scope='fc1', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
scope='dp1')
net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
scope='fc2', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
scope='dp2')
net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')mlp(512,256,k),这时k取40
get_loss函数
- 利用交叉熵计算loss
transform_net.py
1 | def input_transform_net(point_cloud, is_training, bn_decay=None, K=3): |
- 针对输入的T-net,输出一个转移矩阵用于和输入相乘