3D检测框+点云多帧可视化

 

使用 Open3D 对 3D 检测框和点云进行多帧连续可视化

import os.path, copy, numpy as np, time, sys
from scipy.optimize import linear_sum_assignment
from scipy.spatial import ConvexHull
import open3d as o3d
import os
from alfred.fusion.common import draw_3d_box, compute_3d_box_lidar_coords
from alfred.fusion.kitti_fusion import load_pc_from_file
from collections import deque

# RGB colour (each component in [0, 1]) used for every detection class id.
_LABEL_COLORS = {
    0: (0.5, 0, 1),  # unlabeled
    1: (1, 0, 0),    # car
    2: (1, 1, 0),    # bicycle
    3: (0, 1, 1),    # truck
    4: (1, 0, 1),    # person
    5: (0, 1, 0),    # road
    6: (0, 0, 1),    # building
}


def color_sep(label):
    """Return the RGB colour associated with a detection class id.

    Args:
        label: Class id; ids 0..6 have a dedicated colour.

    Returns:
        A new ``[r, g, b]`` list. Ids outside 0..6 get the "unlabeled"
        colour instead of failing with ``UnboundLocalError`` as the
        original if-chain did.
    """
    # Copy into a fresh list so callers may mutate the result freely.
    return list(_LABEL_COLORS.get(label, _LABEL_COLORS[0]))

def visual_option():
    """Open an Open3D visualizer window and configure its render options.

    The window gets a black background, point size 1, line width 100 and a
    visible coordinate frame.

    Returns:
        The ``o3d.visualization.Visualizer`` whose window was created.
    """
    viewer = o3d.visualization.Visualizer()
    viewer.create_window()

    # Render options can only be fetched once the window exists.
    render = viewer.get_render_option()
    render.show_coordinate_frame = True
    render.line_width = 100
    render.point_size = 1
    render.background_color = np.asarray([0, 0, 0])
    return viewer

def file_filter(f):
    """Tell whether the file name ``f`` ends with the ``.txt`` extension.

    Returns:
        ``True`` when the last four characters of ``f`` are ``.txt``,
        otherwise ``False``.
    """
    return f[-4:] == '.txt'

# Hard-coded data locations used by the script, gathered in one place.
VIEWPOINT_JSON = "/home/huangruixin/Downloads/script/show_benewake.json"
DET_ROOT = '/home/huangruixin/Downloads/script/road_bst'
PC_ROOT = '/home/huangruixin/Downloads/script/benewake_overpass'
DET_SUFFIX = '_cls.txt'

# Vertex-index pairs drawn as the 12 line segments of one box.
# NOTE(review): assumes compute_3d_box_lidar_coords returns the 8 corners in
# an order where these pairs are the box edges -- confirm against alfred.
BOX_EDGES = [[0, 1], [1, 2], [2, 3], [3, 0],
             [4, 5], [5, 6], [6, 7], [7, 4],
             [0, 4], [1, 5], [2, 6], [3, 7]]


def _list_sequence_files(det_dir, suffix):
    """Return the detection files in ``det_dir`` sorted by frame number.

    Only names of the form ``<digits><suffix>`` are kept, so unrelated
    ``.txt`` files cannot crash the integer conversion. The original file
    names are preserved (zero padding included) so they still exist on disk.
    """
    names = [
        f for f in os.listdir(det_dir)
        if file_filter(f) and f.endswith(suffix) and f[:-len(suffix)].isdigit()
    ]
    return sorted(names, key=lambda f: int(f[:-len(suffix)]))


def main(result_sha):
    """Replay all frames of ``result_sha`` in a loop until the window closes.

    Each frame shows the point cloud read from ``PC_ROOT`` together with the
    boxes listed in the matching detection file under ``DET_ROOT``. A
    detection row is read as ``label x y z h w l yaw`` (columns 0..7).

    Args:
        result_sha: Name of the sub-directory of ``DET_ROOT`` to visualize.
    """
    # =====================================Data Extraction====================================
    det_dir = os.path.join(DET_ROOT, result_sha)
    seq_files = _list_sequence_files(det_dir, DET_SUFFIX)
    if not seq_files:
        # Without this guard the replay loop below would spin forever.
        print("No '*{}' files found in {}".format(DET_SUFFIX, det_dir))
        return

    # ====================================Visual Preparation==================================
    vis = visual_option()
    # Load the saved viewpoint parameters.
    param = o3d.io.read_pinhole_camera_parameters(VIEWPOINT_JSON)
    ctr = vis.get_view_control()

    pcobj = o3d.geometry.PointCloud()
    pc_added = False   # whether pcobj has been handed to the visualizer
    boxes = []         # LineSet pool, grown on demand (no fixed cap)
    shown = 0          # boxes[0:shown] are currently in the visualizer

    window_open = True
    while window_open:
        for seq_name in seq_files:
            # ndmin=2 keeps a single-detection file as one row instead of a
            # flat vector, so the per-row indexing below stays valid.
            seq_dets = np.loadtxt(os.path.join(det_dir, seq_name),
                                  delimiter=' ', ndmin=2)
            if seq_dets.size == 0:
                seq_dets = seq_dets.reshape(0, 8)   # empty file: zero boxes

            # ===================================Visualization================================
            idx = '%06d' % int(seq_name.split('_')[0])
            pcs = np.fromfile(os.path.join(PC_ROOT, f'{idx}.bin'),
                              dtype=np.float32).reshape(-1, 4)
            pcobj.points = o3d.utility.Vector3dVector(pcs[:, 0:3])
            # Paint every point white.
            pcobj.colors = o3d.utility.Vector3dVector(np.ones((len(pcs), 3)))
            if pc_added:
                vis.update_geometry(pcobj)
            else:
                vis.add_geometry(pcobj)
                pc_added = True

            # =================================Draw Detection result==========================
            for index, d in enumerate(seq_dets):
                if index == len(boxes):
                    boxes.append(o3d.geometry.LineSet())
                box = boxes[index]

                xyz = np.array([d[1:4]])
                hwl = np.array([d[4:7]])
                r_y = [-d[7]]
                pts3d = compute_3d_box_lidar_coords(
                    xyz, hwl, angles=r_y, origin=(0.5, 0.5, 0.5), axis=2)
                color = color_sep(int(d[0]))

                box.points = o3d.utility.Vector3dVector(pts3d[0])
                box.lines = o3d.utility.Vector2iVector(BOX_EDGES)
                box.colors = o3d.utility.Vector3dVector([color] * len(BOX_EDGES))

                # Track membership explicitly: a box removed earlier still
                # has lines, so has_lines() cannot tell whether it must be
                # re-added or merely updated.
                if index < shown:
                    vis.update_geometry(box)
                else:
                    vis.add_geometry(box)

            # Drop boxes left over from a frame that had more detections.
            for stale in boxes[len(seq_dets):shown]:
                vis.remove_geometry(stale)
            shown = len(seq_dets)
            print("num dets is {}".format(len(seq_dets)))

            # Re-apply the fixed view; adding geometry may reset the camera.
            ctr.convert_from_pinhole_camera_parameters(param)
            if not vis.poll_events():
                # The window was closed: stop replaying.
                window_open = False
                break
            vis.update_renderer()

            # To store a new fixed viewpoint, save the result of
            # vis.get_view_control().convert_to_pinhole_camera_parameters()
            # with o3d.io.write_pinhole_camera_parameters(VIEWPOINT_JSON, ...).
            print("{} has been visualized".format(seq_name))

    vis.destroy_window()


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("Usage: python main.py result_sha(e.g., car_3d_det_test)")
        sys.exit(1)
    main(sys.argv[1])

参考资料

  1. Deecamp笔记——点云目标跟踪 & Open3D连续可视化
  2. AB3DMOT项目
  3. Open3D 官方文档——保存与加载视角(load-save-viewpoint)