Loading...
Searching...
No Matches
viser_visualizer.py
1"""Viser helper class for robot trajectory visualization with OMPL"""
2
3import time
4import threading
5import sys
6import termios
7import tty
8import select
9import numpy as np
10from pathlib import Path
11import viser
12from viser.extras import ViserUrdf
13from robot_descriptions.loaders.yourdfpy import load_robot_description
14from typing import Optional, List, Dict, Any
15from scipy.spatial.transform import Rotation
16
17
# Robots whose planner controls only a subset of the URDF's actuated joints.
# Maps robot name -> ordered list of URDF joint names: planning DOF i drives
# the URDF joint named at position i (consumed by _map_plan_config_to_urdf).
JOINT_MAPPINGS = dict(
    fetch=[
        "torso_lift_joint",
        "shoulder_pan_joint",
        "shoulder_lift_joint",
        "upperarm_roll_joint",
        "elbow_flex_joint",
        "forearm_roll_joint",
        "wrist_flex_joint",
        "wrist_roll_joint",
    ],
)
32
def __init__(
    self, robot_name: str, robot_dimension: int, port: Optional[int] = None
):
    """Initialize the visualizer: start a viser server and load the robot model.

    Args:
        robot_name: Name of the robot. Its URDF is loaded from
            ``robot_descriptions`` as ``f"{robot_name}_description"``, except for
            ``"ur5"``, which is loaded from VAMP's bundled resources.
        robot_dimension: Number of planning degrees of freedom for the robot
        port: Optional port number for the viser server (default: 8080)

    Raises:
        ValueError: If the robot's URDF cannot be loaded. The underlying error is
            chained as ``__cause__``.
    """
    self.robot_name = robot_name
    self.dimension = robot_dimension

    # Planning-DOF -> URDF joint names, or None when the planner drives the
    # URDF's joints directly.
    self.joint_mapping = self.JOINT_MAPPINGS.get(robot_name, None)

    if port is not None:
        self.server = viser.ViserServer(port=port)
    else:
        self.server = viser.ViserServer()

    # Load robot URDF from robot_descriptions
    description_name = f"{robot_name}_description"
    try:
        if description_name == "ur5_description":
            # VAMP uses a UR5 with a Robotiq gripper; its URDF and meshes are
            # loaded from the VAMP repository checked out relative to this file.
            import yourdfpy

            vamp_folder = Path(__file__).parent.parent.parent / "external" / "vamp"
            ur5_urdf_file = vamp_folder / "resources" / "ur5" / "ur5.urdf"
            mesh_dir = vamp_folder / "resources" / "ur5"

            def package_aware_handler(fname):
                """Resolve package:// URIs by stripping the prefix and joining with mesh_dir."""
                if fname.startswith("package://"):
                    fname = fname[len("package://") :]
                return str(mesh_dir / fname)

            self.robot_urdf = yourdfpy.URDF.load(
                str(ur5_urdf_file),
                filename_handler=package_aware_handler,
                mesh_dir=mesh_dir,
                load_meshes=True,
                load_collision_meshes=True,
            )
        else:
            # URDF loaded from robot_descriptions
            self.robot_urdf = load_robot_description(description_name)
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise ValueError(
            f"Could not load URDF for '{description_name}'. "
            f"Make sure robot_descriptions has this robot. Error: {e}"
        ) from e

    self.urdf_vis = ViserUrdf(
        self.server, self.robot_urdf, root_node_name=f"/{robot_name}"
    )

    # Trajectory playback state, populated by visualize_trajectory().
    self._trajectory = None
    self._slider = None
    self._playing = None
    self._start_time = None
96
def reset(self):
    """Reset the entire scene, re-add the robot, and clear playback state.

    The viser scene is cleared and the robot is re-created from the
    already-loaded URDF. The "Timestep" slider and "Playing" checkbox created
    by visualize_trajectory() live in the GUI, not the scene, so
    ``scene.reset()`` does not remove them. They are removed explicitly here;
    otherwise every reset + visualize_trajectory cycle would leave another
    pair of stale controls behind.
    """
    self.server.scene.reset()
    self.urdf_vis = ViserUrdf(
        self.server, self.robot_urdf, root_node_name=f"/{self.robot_name}"
    )

    for handle in (self._slider, self._playing):
        if handle is not None:
            handle.remove()

    # Reset internal state
    self._trajectory = None
    self._slider = None
    self._playing = None
    self._start_time = None
109
def set_camera(self, position, target):
    """Configure where the viewer's initial camera sits and what it looks at.

    Args:
        position: Camera position as array [x, y, z]
        target: Camera look-at target as array [x, y, z]
    """
    camera = self.server.initial_camera
    camera.position = np.array(position, dtype=np.float64)
    camera.look_at = np.array(target, dtype=np.float64)
119
def load_mbm_environment(
    self,
    problem_data: Dict[str, Any],
    ignore_names: Optional[List[str]] = None,
    color=(0.8, 0.4, 0.2, 0.7),
    padding: float = 0.0,
):
    """Load obstacles from an MBM problem description into the scene.

    Args:
        problem_data: Dictionary with optional 'sphere', 'cylinder' and 'box'
            lists of obstacle dicts (each with a 'name' and 'position'), plus an
            optional 'problem' key. When ``problem_data["problem"] == "box"``,
            cylinders are drawn as boxes.
        ignore_names: Obstacle names to skip (default: none)
        color: RGBA color tuple (0-1 range) for obstacles
        padding: Extra size applied to obstacles. Spheres and cylinders get
            ``padding`` added to their radius; boxes get ``padding / 2`` added
            to each half extent; cylinder length is not padded.

    NOTE(review): the per-shape padding rules above differ (a box face grows by
    padding/2 while a sphere surface grows by padding). This is preserved as-is
    in case it mirrors the planner's collision padding; confirm it is intended.
    """
    # A set gives O(1) membership; None replaces the shared mutable [] default.
    skipped = set(ignore_names) if ignore_names else set()

    def rotation_matrix_of(obj):
        """Rotation matrix from the obstacle's 'xyz' Euler angles (identity if absent)."""
        euler_xyz = np.array(obj.get("orientation_euler_xyz", [0.0, 0.0, 0.0]))
        return Rotation.from_euler("xyz", euler_xyz).as_matrix()

    for obj in problem_data.get("sphere", []):
        if obj["name"] in skipped:
            continue
        self.add_sphere(
            position=np.array(obj["position"]),
            radius=obj["radius"] + padding,
            color=color,
            name=f"/sphere_{obj['name']}",
        )

    is_box_problem = problem_data.get("problem") == "box"

    # Load cylinders (or as boxes if is_box_problem)
    for obj in problem_data.get("cylinder", []):
        if obj["name"] in skipped:
            continue

        position = np.array(obj["position"])
        rotation_matrix = rotation_matrix_of(obj)
        radius = obj["radius"] + padding
        length = obj["length"]

        if is_box_problem:
            # HACK: render the cylinder as its bounding box, matching VAMP's
            # capsule over-approximation for "box" problems.
            self.add_box(
                position=position,
                half_extents=[radius, radius, length / 2.0],
                rotation_matrix=rotation_matrix,
                color=color,
                name=f"/cylinder_as_box_{obj['name']}",
            )
        else:
            self.add_cylinder(
                position=position,
                radius=radius,
                length=length,
                rotation_matrix=rotation_matrix,
                color=color,
                name=f"/cylinder_{obj['name']}",
            )

    for obj in problem_data.get("box", []):
        if obj["name"] in skipped:
            continue
        self.add_box(
            position=np.array(obj["position"]),
            half_extents=[h + padding / 2 for h in obj["half_extents"]],
            rotation_matrix=rotation_matrix_of(obj),
            color=color,
            name=f"/box_{obj['name']}",
        )
205
# (A duplicate add_point_cloud definition was removed here: only the last
# definition of a name in a class body takes effect, so the add_point_cloud
# defined next to add_grid is the one in use.)
224
225 def _generate_name(self, prefix: str) -> str:
226 """Generate a unique name based on existing objects with the same prefix
227
228 Args:
229 prefix: Name prefix (e.g., '/sphere_', '/box_', '/cylinder_')
230
231 Returns:
232 Unique name with numeric suffix
233 """
234 count = len(
235 [
236 k
237 for k in self.server.scene._handle_from_node_name.keys()
238 if k.startswith(prefix)
239 ]
240 )
241 return f"{prefix}{count}"
242
243 def _rotation_to_wxyz(self, rotation_matrix: np.ndarray) -> np.ndarray:
244 """Convert rotation matrix to wxyz quaternion format
245
246 Args:
247 rotation_matrix: 3x3 rotation matrix
248
249 Returns:
250 Quaternion in wxyz format
251 """
252 rotation = Rotation.from_matrix(rotation_matrix)
253 quat = rotation.as_quat() # Returns xyzw
254 wxyz = np.array([quat[3], quat[0], quat[1], quat[2]])
255 return wxyz
256
257 def _update_robot_config(
258 self, trajectory_idx: int, gripper_dof: float = 0.0
259 ) -> None:
260 """Update the robot configuration in the visualization
261
262 Args:
263 trajectory_idx: Index into the current trajectory
264 gripper_dof: Value for extra gripper DOF if needed (default: 0.0)
265 """
266 if self._trajectory is None:
267 return
268
269 idx = min(trajectory_idx, len(self._trajectory) - 1)
270 plan_config = self._trajectory[idx].tolist()
271 config = self._map_plan_config_to_urdf(plan_config)
272
273 self.urdf_vis.update_cfg(config)
274
def add_sphere(
    self,
    position: np.ndarray,
    radius: float,
    color=(1, 0, 0, 0.75),
    name: Optional[str] = None,
):
    """Place a spherical obstacle in the scene.

    Args:
        position: Sphere center [x, y, z]
        radius: Sphere radius
        color: RGB or RGBA tuple in the 0-1 range; a fourth entry sets the
            opacity, otherwise the sphere is fully opaque
        name: Scene node name; generated from the "/sphere_" prefix when None
    """
    node_name = self._generate_name("/sphere_") if name is None else name

    if len(color) == 4:
        rgb, alpha = color[:3], color[3]
    else:
        rgb, alpha = color, 1.0

    self.server.scene.add_icosphere(
        name=node_name,
        position=tuple(position),
        radius=radius,
        color=rgb,
        opacity=alpha,
    )
300
def add_box(
    self,
    position: np.ndarray,
    half_extents: List[float],
    rotation_matrix: Optional[np.ndarray] = None,
    color=(0.8, 0.4, 0.2, 0.75),
    name: Optional[str] = None,
):
    """Place an oriented box obstacle in the scene.

    Args:
        position: Box center [x, y, z]
        half_extents: Half side lengths [x, y, z]; the rendered box is twice as
            large along each axis
        rotation_matrix: 3x3 orientation matrix; identity when None
        color: RGB or RGBA tuple in the 0-1 range; a fourth entry sets the
            opacity, otherwise the box is fully opaque
        name: Scene node name; generated from the "/box_" prefix when None
    """
    node_name = self._generate_name("/box_") if name is None else name

    # viser's add_box takes full side lengths, not half extents.
    dimensions = tuple(h * 2 for h in half_extents)

    orientation = np.eye(3) if rotation_matrix is None else rotation_matrix
    quat_wxyz = self._rotation_to_wxyz(orientation)

    if len(color) == 4:
        rgb, alpha = color[:3], color[3]
    else:
        rgb, alpha = color, 1.0

    self.server.scene.add_box(
        name=node_name,
        dimensions=dimensions,
        position=tuple(position),
        wxyz=tuple(quat_wxyz),
        color=rgb,
        opacity=alpha,
    )
337
def add_cylinder(
    self,
    position: np.ndarray,
    radius: float,
    length: float,
    rotation_matrix: Optional[np.ndarray] = None,
    color=(0.8, 0.4, 0.2, 0.75),
    name: Optional[str] = None,
):
    """Place an oriented cylinder obstacle in the scene.

    Args:
        position: Cylinder center [x, y, z]
        radius: Cylinder radius
        length: Cylinder length, passed to viser as its height
        rotation_matrix: 3x3 orientation matrix; identity when None
        color: RGB or RGBA tuple in the 0-1 range; a fourth entry sets the
            opacity, otherwise the cylinder is fully opaque
        name: Scene node name; generated from the "/cylinder_" prefix when None
    """
    node_name = self._generate_name("/cylinder_") if name is None else name

    orientation = np.eye(3) if rotation_matrix is None else rotation_matrix
    quat_wxyz = self._rotation_to_wxyz(orientation)

    if len(color) == 4:
        rgb, alpha = color[:3], color[3]
    else:
        rgb, alpha = color, 1.0

    self.server.scene.add_cylinder(
        name=node_name,
        radius=radius,
        height=length,
        position=tuple(position),
        wxyz=tuple(quat_wxyz),
        color=rgb,
        opacity=alpha,
    )
374
def add_grid(self, width: float = 2.0, height: float = 2.0, cell_size: float = 0.1):
    """Draw a reference grid as the scene node "/grid".

    Args:
        width: Grid width
        height: Grid height
        cell_size: Size of each grid cell
    """
    grid_options = {"width": width, "height": height, "cell_size": cell_size}
    self.server.scene.add_grid("/grid", **grid_options)
386
def add_point_cloud(
    self,
    points: np.ndarray,
    color: Optional[np.ndarray] = None,
    point_size: float = 0.01,
    name: str = "/point_cloud",
):
    """Add a point cloud to the scene.

    Args:
        points: Array-like of shape (N, 3) containing point coordinates
        color: Optional array of shape (N, 3) containing RGB colors (0-1 range)
            for each point; every point is red when None
        point_size: Rendered size of each point, forwarded to viser
        name: Scene node name (default "/point_cloud"). Calls that reuse a name
            target the same node, so pass distinct names to show several clouds.
    """
    points = np.asarray(points)
    if color is None:
        color = np.tile(np.array([1.0, 0.0, 0.0]), (points.shape[0], 1))

    self.server.scene.add_point_cloud(
        name, points=points, colors=color, point_size=point_size
    )
405
406 def _map_plan_config_to_urdf(self, plan_config: List[float]) -> List[float]:
407 """Map planning configuration to full URDF configuration
408
409 For robots like fetch that use a subset of joints in planning, this method
410 maps the planning DOFs to the correct positions in the full URDF configuration.
411
412 Args:
413 plan_config: Configuration from planner (e.g., 9 DOFs for fetch)
414
415 Returns:
416 Full URDF configuration with all joints
417 """
418
419 if self.robot_name == "panda":
420 # append gripper
421 plan_config.append(0.05)
422 return plan_config
423
424 if self.joint_mapping is None:
425 # No mapping needed, use config as-is
426 return plan_config
427
428 # Get all joint names from URDF
429
430 all_joints = [joint.name for joint in self.robot_urdf.actuated_joints]
431 n_total_joints = len(all_joints)
432
433 # Create full configuration with zeros (neutral positions)
434 full_config = [0.0] * n_total_joints
435
436 # Map planning joints to their positions in URDF
437 for planning_idx, joint_name in enumerate(self.joint_mapping):
438 if joint_name in all_joints:
439 urdf_idx = all_joints.index(joint_name)
440 if planning_idx < len(plan_config):
441 full_config[urdf_idx] = plan_config[planning_idx]
442
443 # if fetch, open gripper
444 if self.robot_name == "fetch":
445 full_config[-1] = 0.035
446 full_config[-2] = 0.035
447
448 return full_config
449
def visualize_trajectory(
    self, trajectory: np.ndarray, playtime: Optional[float] = None
):
    """Load a trajectory and add interactive playback controls to the GUI.

    Adds a "Timestep" slider and a "Playing" checkbox (initially checked).
    Controls created by an earlier call are removed first, so repeated calls do
    not stack duplicate widgets.

    Args:
        trajectory: Array-like of shape (timesteps, joints) containing joint
            configurations. If it has exactly one column more than the robot's
            dimension (e.g. a gripper DOF), that trailing column is dropped.
        playtime: Seconds for one full auto-play loop, stored as
            ``self._playtime`` for visualization_step. Defaults to 0.1 s per
            timestep.

    Raises:
        ValueError: If the trajectory is not 2D, has no timesteps, has the
            wrong number of DOFs, or ``playtime`` is not positive.
    """
    trajectory = np.asarray(trajectory)
    if trajectory.ndim != 2:
        raise ValueError(
            f"Trajectory must be 2D array, got shape {trajectory.shape}"
        )
    if len(trajectory) == 0:
        # An empty trajectory would create a slider with max=-1.
        raise ValueError("Trajectory must contain at least one timestep")

    if trajectory.shape[1] != self.dimension:
        # Handle case where trajectory might have extra (gripper?) DOF
        if trajectory.shape[1] == self.dimension + 1:
            trajectory = trajectory[:, : self.dimension]
        else:
            raise ValueError(
                f"Trajectory has {trajectory.shape[1]} DOFs but robot has {self.dimension} DOFs"
            )

    if playtime is None:
        playtime = 0.1 * len(trajectory)
    elif playtime <= 0:
        raise ValueError(f"playtime must be positive, got {playtime}")

    # Remove controls from a previous call so they do not accumulate.
    for handle in (getattr(self, "_slider", None), getattr(self, "_playing", None)):
        if handle is not None:
            handle.remove()

    self._trajectory = trajectory

    self._slider = self.server.gui.add_slider(
        "Timestep", min=0, max=len(trajectory) - 1, step=1, initial_value=0
    )

    self._playing = self.server.gui.add_checkbox("Playing", initial_value=True)

    self._playtime = float(playtime)
    self._start_time = time.time()
479
def visualization_step(self):
    """Perform one step of the visualization update.

    When the "Playing" checkbox is checked, the slider is driven by wall-clock
    time so the trajectory loops once every ``self._playtime`` seconds. If no
    positive ``_playtime`` is available, it falls back to 0.1 s per timestep,
    the same per-frame default as play_once / play_until_key_pressed. The robot
    is then posed at the slider's timestep. Sleeps 10 ms so a caller's
    ``while True`` loop does not spin.
    """
    if self._trajectory is None:
        return

    n_frames = len(self._trajectory)
    if self._playing is not None and self._playing.value:
        playtime = getattr(self, "_playtime", None)
        if playtime is None or playtime <= 0:
            playtime = 0.1 * n_frames
        if getattr(self, "_start_time", None) is None:
            self._start_time = time.time()

        elapsed = time.time() - self._start_time
        progress = (elapsed % playtime) / playtime
        # Scale by n_frames (not n_frames - 1) so every frame, including the
        # last, gets an equal share of the loop; the clamp guards rounding.
        self._slider.value = min(int(progress * n_frames), n_frames - 1)

    # Update robot configuration with consistent gripper DOF of 0.0
    self._update_robot_config(self._slider.value, gripper_dof=0.0)

    time.sleep(0.01)
495
def visualization_loop(self):
    """Run visualization_step repeatedly until the user presses Ctrl+C."""
    url = f"http://localhost:{self.server.port}"
    print(f"Visualization running on {url}")
    print("Press Ctrl+C to stop...")
    try:
        while True:
            self.visualization_step()
    except KeyboardInterrupt:
        print("\nVisualization stopped.")
505
def play_once(self, dt=0.1):
    """Step through every timestep of the loaded trajectory exactly once.

    For each timestep the GUI slider is moved, the robot is posed, and the
    call then sleeps ``dt`` seconds.

    Args:
        dt: Time delay between frames in seconds (default: 0.1)
    """
    if self._trajectory is None:
        print("No trajectory loaded. Call visualize_trajectory() first.")
        return

    n_frames = len(self._trajectory)
    frame = 0
    while frame < n_frames:
        self._slider.value = frame
        self._update_robot_config(frame, gripper_dof=0.0)
        time.sleep(dt)
        frame += 1
521
def play_until_key_pressed(self, key="any", dt=0.1):
    """Play the loaded trajectory until a key is pressed in the terminal.

    While running, the GUI "Playing" checkbox controls auto-advance (one frame
    per ``dt``) and the "Timestep" slider can be dragged to jump to a frame.
    Keys are read from stdin on a background thread with the terminal in
    cbreak mode. If stdin is not a TTY, no key can stop playback and only
    Ctrl+C ends it.

    Args:
        key: Key to wait for. Use 'any' to stop on any key press (default: 'any')
        dt: Time delay between frames in seconds (default: 0.1)

    Returns:
        str: The last key read, or None if no trajectory is loaded or no key
        was read.

    Raises:
        KeyboardInterrupt: Propagated when the user presses Ctrl+C.
    """
    if self._trajectory is None:
        print("No trajectory loaded. Call visualize_trajectory() first.")
        return None

    print(
        f"Visualization running. Press {key if key != 'any' else 'any key'} to stop or click play button..."
    )

    pressed_key = [None]
    stop_flag = threading.Event()

    def wait_for_key():
        """Background listener: record keys from stdin until the stop key arrives."""
        try:
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
        except (OSError, ValueError, AttributeError, termios.error):
            # Not a TTY (or stdin replaced/closed): skip key handling.
            return

        try:
            # cbreak, unlike raw mode, keeps signal generation (so Ctrl+C still
            # raises KeyboardInterrupt) and output post-processing (so "\n"
            # printed meanwhile still returns the cursor).
            tty.setcbreak(fd)
            while not stop_flag.is_set():
                try:
                    ready, _, _ = select.select([sys.stdin], [], [], 0.05)
                    if not ready:
                        continue
                    char = sys.stdin.read(1)
                except (OSError, ValueError):
                    break
                if not char:
                    # EOF on stdin: nothing more can be read.
                    break
                pressed_key[0] = char
                if key == "any" or char == key:
                    stop_flag.set()
                    break
        except Exception as e:
            print(f"Key listener error: {e}")
        finally:
            try:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
            except (termios.error, OSError):
                pass

    # Start thread to wait for key
    key_thread = threading.Thread(target=wait_for_key, daemon=True)
    key_thread.start()

    # Run visualization loop until key is pressed.
    # The play checkbox controls auto-advancement of the slider.
    try:
        frame_idx = 0
        last_playing_state = False
        last_slider_value = 0

        while not stop_flag.is_set():
            playing = self._playing is not None and self._playing.value

            # Playback (re)started: continue from wherever the slider is now.
            if playing and not last_playing_state and self._slider is not None:
                frame_idx = self._slider.value
            last_playing_state = playing

            current_slider_value = (
                self._slider.value if self._slider is not None else frame_idx
            )
            if current_slider_value != last_slider_value:
                # User moved the slider, sync frame_idx to it
                frame_idx = current_slider_value
                last_slider_value = current_slider_value

            # If playing is checked, auto-advance the slider
            if playing:
                if self._slider is not None:
                    self._slider.value = frame_idx
                    # Record our own write; otherwise the next iteration would
                    # treat it as a user drag, rewind frame_idx, and show every
                    # frame twice.
                    last_slider_value = frame_idx
                frame_idx = (frame_idx + 1) % len(self._trajectory)

            # Always update robot position based on current slider value
            if self._slider is not None:
                current_idx = self._slider.value
            else:
                current_idx = frame_idx

            self._update_robot_config(current_idx, gripper_dof=0.0)

            time.sleep(dt)
    finally:
        stop_flag.set()
        # Let the listener restore terminal settings before anything is printed.
        key_thread.join(timeout=1.0)

    print(f"\nVisualization stopped. Key pressed: {pressed_key[0]}")
    return pressed_key[0]
__init__(self, str robot_name, int robot_dimension, Optional[int] port=None)