import os.path, copy, numpy as np, time, sys
from scipy.optimize import linear_sum_assignment
from scipy.spatial import ConvexHull
import open3d as o3d
import os
from alfred.fusion.common import draw_3d_box, compute_3d_box_lidar_coords
from alfred.fusion.kitti_fusion import load_pc_from_file
from collections import deque
def color_sep(label):
    """Return the RGB colour used to draw objects of class ``label``.

    Args:
        label: Integer class id. The known ids are 0 (unlabeled), 1 (car),
            2 (bicycle), 3 (truck), 4 (person), 5 (road) and 6 (building).

    Returns:
        A new three-element list ``[r, g, b]`` with components in ``[0, 1]``.
        Ids that are not in the table get the "unlabeled" colour.
    """
    palette = {
        0: (0.5, 0, 1),  # unlabeled
        1: (1, 0, 0),    # car
        2: (1, 1, 0),    # bicycle
        3: (0, 1, 1),    # truck
        4: (1, 0, 1),    # person
        5: (0, 1, 0),    # road
        6: (0, 0, 1),    # building
    }
    # An unknown id used to leave the result unassigned and raise
    # UnboundLocalError; fall back to the "unlabeled" colour instead.
    # A fresh list is returned so callers may mutate it freely.
    return list(palette.get(label, palette[0]))
def visual_option():
    """Open an Open3D visualizer window and configure how it renders.

    The render options are set to a black background, a point size of 1,
    a line width of 100 and a visible coordinate frame.

    Returns:
        The ``o3d.visualization.Visualizer`` whose window was created.
    """
    viewer = o3d.visualization.Visualizer()
    viewer.create_window()

    render = viewer.get_render_option()
    render.background_color = np.asarray([0, 0, 0])
    render.point_size = 1
    render.line_width = 100
    render.show_coordinate_frame = True

    return viewer
def file_filter(f):
    """Tell whether the file name ``f`` ends with the ``.txt`` extension.

    Args:
        f: File name to inspect.

    Returns:
        ``True`` when the last four characters of ``f`` are ``.txt``,
        otherwise ``False``.
    """
    return f[-4:] == '.txt'
# Fixed input locations used by this visualisation script.
DETECTION_ROOT = '/home/huangruixin/Downloads/script/road_bst'
POINT_CLOUD_ROOT = '/home/huangruixin/Downloads/script/benewake_overpass'
VIEWPOINT_FILE = "/home/huangruixin/Downloads/script/show_benewake.json"

# Detection files are named "<frame number>_cls.txt".
DETECTION_SUFFIX = '_cls.txt'

# The drawing code reads columns 0..7 of every detection row.
DETECTION_COLUMNS = 8

# Pairs of corner indices forming the 12 edges of a box with 8 corners.
BOX_EDGES = [[0, 1], [1, 2], [2, 3], [3, 0],
             [4, 5], [5, 6], [6, 7], [7, 4],
             [0, 4], [1, 5], [2, 6], [3, 7]]


def list_sequence_files(directory, suffix=DETECTION_SUFFIX):
    """Return the detection file names in ``directory`` in frame order.

    Only names of the form ``<decimal frame number><suffix>`` are kept; other
    files are ignored instead of crashing the integer conversion. The names
    are returned unchanged (zero padding is preserved) and sorted by their
    numeric frame number, with the name itself as a tie breaker.

    Args:
        directory: Directory to scan.
        suffix: Trailing part of a detection file name.

    Returns:
        A list of file names (not full paths).
    """
    names = [name for name in os.listdir(directory)
             if name.endswith(suffix) and name[:-len(suffix)].isdecimal()]
    names.sort(key=lambda name: (int(name[:-len(suffix)]), name))
    return names


def load_detections(path):
    """Load a space-separated detection file as a 2-D array, one row each.

    Args:
        path: Path of the text file to read.

    Returns:
        A 2-D ``numpy`` array. A file holding a single detection yields one
        row rather than a flat vector, and a file without data yields an
        array with zero rows, so callers can always iterate over rows.
    """
    dets = np.loadtxt(path, delimiter=' ', ndmin=2)
    if dets.size == 0:
        return np.empty((0, DETECTION_COLUMNS))
    return dets


def main():
    """Replay a detection sequence over its point clouds in an Open3D window.

    The single command-line argument names a sub-directory of
    ``DETECTION_ROOT``. For every detection file found there, the point
    cloud with the matching six-digit frame number is loaded from
    ``POINT_CLOUD_ROOT`` and one wire-frame box per detection row is drawn.
    Each row is read as: column 0 class id (selects the colour), columns 1-3
    passed as ``xyz``, columns 4-6 passed as ``hwl`` and column 7 negated
    and passed as the rotation angle. The sequence repeats until the window
    is closed.
    """
    if len(sys.argv) != 2:
        print("Usage: python main.py result_sha(e.g., car_3d_det_test)")
        sys.exit(1)

    # ===================================== Data Extraction =====================================
    result_sha = sys.argv[1]
    det_dir = os.path.join(DETECTION_ROOT, result_sha)
    seq_file_list = list_sequence_files(det_dir)
    if not seq_file_list:
        # Without this guard the replay loop would spin forever on nothing.
        print("No '*{}' detection files found in {}".format(DETECTION_SUFFIX, det_dir))
        sys.exit(1)

    # ==================================== Visual Preparation ===================================
    vis = visual_option()
    # Load the stored viewpoint parameters.
    # To record a new fixed viewpoint, take
    # vis.get_view_control().convert_to_pinhole_camera_parameters() and save it
    # with o3d.io.write_pinhole_camera_parameters(VIEWPOINT_FILE, ...).
    param = o3d.io.read_pinhole_camera_parameters(VIEWPOINT_FILE)
    ctr = vis.get_view_control()
    pcobj = o3d.geometry.PointCloud()
    # Pool of reusable box geometries; it grows on demand, so the number of
    # detections per frame is not capped.
    line_set = []
    # line_set[:shown_boxes] are the boxes currently registered with `vis`.
    shown_boxes = 0

    window_open = True
    while window_open:
        for seq_name in seq_file_list:
            seq_dets = load_detections(os.path.join(det_dir, seq_name))

            # =================================== Visualization =================================
            idx = '%06d' % int(seq_name[:-len(DETECTION_SUFFIX)])
            # Each point is stored as four float32 values; the first three
            # are used as its coordinates.
            pcs = np.fromfile(os.path.join(POINT_CLOUD_ROOT, f'{idx}.bin'),
                              dtype=np.float32).reshape(-1, 4)
            # Paint every point white.
            pcobj.colors = o3d.utility.Vector3dVector(np.ones((len(pcs), 3)))
            cloud_is_new = pcobj.is_empty()
            pcobj.points = o3d.utility.Vector3dVector(pcs[:, 0:3].astype(np.float64))
            if cloud_is_new:
                vis.add_geometry(pcobj)
            else:
                vis.update_geometry(pcobj)

            # ============================== Draw Detection result ==============================
            for index, d in enumerate(seq_dets):
                if index == len(line_set):
                    line_set.append(o3d.geometry.LineSet())
                box = line_set[index]

                xyz = np.array([d[1:4]])
                hwl = np.array([d[4:7]])
                r_y = [-d[7]]
                pts3d = compute_3d_box_lidar_coords(xyz, hwl, angles=r_y, origin=(0.5, 0.5, 0.5), axis=2)
                color = color_sep(int(d[0]))

                box.points = o3d.utility.Vector3dVector(pts3d[0])
                box.lines = o3d.utility.Vector2iVector(BOX_EDGES)
                box.colors = o3d.utility.Vector3dVector([color for _ in BOX_EDGES])
                # Whether a box is registered is tracked by position in the
                # pool, not by has_lines(): a box that was removed earlier
                # still has lines and must be added again.
                if index < shown_boxes:
                    vis.update_geometry(box)
                else:
                    vis.add_geometry(box)

            # Drop boxes left over from a frame that had more detections,
            # including across the wrap-around to the first frame.
            for stale_box in line_set[len(seq_dets):shown_boxes]:
                vis.remove_geometry(stale_box)
            shown_boxes = len(seq_dets)

            print("num dets is {}".format(len(seq_dets)))

            # Re-apply the stored viewpoint for this frame.
            ctr.convert_from_pinhole_camera_parameters(param)
            # poll_events() reports False once the window has been closed.
            if not vis.poll_events():
                window_open = False
                break
            vis.update_renderer()
            print("{} has been visualized".format(seq_name))

    vis.destroy_window()


if __name__ == '__main__':
    main()