# Source code for rofunc.planning_control.lqt.visualize

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib.pyplot as plt
import numpy as np
import rofunc as rf


# def plot_2d(cfg, x_hat_l, x_hat_r, idx_slices, tl, via_point_l, via_point_r):
#     # TODO: check
#     plt.figure()
#     plt.title("2D Trajectory")
#     plt.scatter(x_hat_l[0, 0], x_hat_l[0, 1], c='blue', s=100)
#     plt.scatter(x_hat_r[0, 0], x_hat_r[0, 1], c='green', s=100)
#     for slice_t in idx_slices:
#         plt.scatter(param["muQ_l"][slice_t][0], param["muQ_l"][slice_t][1], c='red', s=100)
#         plt.scatter(param["muQ_r"][slice_t][0], param["muQ_r"][slice_t][1], c='orange', s=100)
#         plt.plot([param["muQ_l"][slice_t][0], param["muQ_r"][slice_t][0]],
#                  [param["muQ_l"][slice_t][1], param["muQ_r"][slice_t][1]], linewidth=2, color='black')
#     plt.plot(x_hat_l[:, 0], x_hat_l[:, 1], c='blue')
#     plt.plot(x_hat_r[:, 0], x_hat_r[:, 1], c='green')
#     plt.axis("off")
#     plt.gca().set_aspect('equal', adjustable='box')
#
#     fig, axs = plt.subplots(3, 1)
#     for i, t in enumerate(tl):
#         axs[0].scatter(t, param["muQ_l"][idx_slices[i]][0], c='red')
#         axs[0].scatter(t, param["muQ_r"][idx_slices[i]][0], c='orange')
#     axs[0].plot(x_hat_l[:, 0], c='blue')
#     axs[0].plot(x_hat_r[:, 0], c='green')
#     axs[0].set_ylabel("$x_1$")
#     axs[0].set_xticks([0, cfg.nbData])
#     axs[0].set_xticklabels(["0", "T"])
#
#     for i, t in enumerate(tl):
#         axs[1].scatter(t, param["muQ_l"][idx_slices[i]][1], c='red')
#         axs[1].scatter(t, param["muQ_r"][idx_slices[i]][1], c='orange')
#     axs[1].plot(x_hat_l[:, 1], c='blue')
#     axs[1].plot(x_hat_r[:, 1], c='green')
#     axs[1].set_ylabel("$x_2$")
#     axs[1].set_xlabel("$t$")
#     axs[1].set_xticks([0, cfg.nbData])
#     axs[1].set_xticklabels(["0", "T"])
#
#     dis_lst = []
#     for i in range(len(x_hat_l)):
#         dis_lst.append(np.sqrt(np.sum(np.square(x_hat_l[i, :2] - x_hat_r[i, :2]))))
#
#     dis_lst = np.array(dis_lst)
#     timestep = np.arange(len(dis_lst))
#     axs[2].plot(timestep, dis_lst)
#     axs[2].set_ylabel("traj_dis")
#     axs[2].set_xlabel("$t$")
#     axs[2].set_xticks([0, cfg.nbData])
#     axs[2].set_xticklabels(["0", "T"])
#
#     dis_lst = []
#     via_point_l = np.array(via_point_l)
#     via_point_r = np.array(via_point_r)
#     for i in range(len(via_point_l)):
#         dis_lst.append(np.sqrt(np.sum(np.square(via_point_l[i, :2] - via_point_r[i, :2]))))
#
#     dis_lst = np.array(dis_lst)
#     timestep = np.arange(len(dis_lst))
#     axs[3].plot(timestep, dis_lst)
#
#     plt.show()


def plot_3d_uni(x_hat, muQ=None, idx_slices=None, ori=False, save=False, save_file_name=None, g_ax=None, title=None,
                legend=None, for_test=False):
    """
    Plot a unimanual trajectory in 3D, optionally with reference points, and optionally save it.

    :param x_hat: trajectory data handed to ``rf.visualab.traj_plot``. A ``list`` is passed through
        unchanged; a non-list object whose ``shape`` has exactly two entries gets a leading axis added
        first (``np.expand_dims(x_hat, axis=0)``).
    :param muQ: indexable collection of reference points (presumably the via-point targets of the
        LQT problem -- confirm with the caller). Only used together with ``idx_slices``.
    :param idx_slices: iterable of indices into ``muQ``; for each index the first three components of
        ``muQ[index]`` are drawn as a red marker.
    :param ori: forwarded to ``rf.visualab.traj_plot`` as ``ori``.
    :param save: if True, ``np.array(x_hat)`` is written with ``np.save`` to ``save_file_name``.
    :param save_file_name: target passed to ``np.save``; required when ``save`` is True.
    :param g_ax: existing axes to draw on. When None, a new 4x4 figure with a 3D subplot is created.
    :param title: plot title; defaults to ``'Unimanual trajectory'``.
    :param legend: forwarded to ``rf.visualab.traj_plot`` as ``legend``.
    :param for_test: if True, ``plt.show()`` is never called.
    :raises ValueError: if ``save`` is True but ``save_file_name`` is None.
    """
    # Validate before any drawing: an ``assert`` would vanish under ``python -O`` and would
    # otherwise only fire after a figure had already been created and populated.
    if save and save_file_name is None:
        raise ValueError("save_file_name must be provided when save=True")

    if g_ax is None:
        ax = plt.figure(figsize=(4, 4)).add_subplot(111, projection='3d', fc='white')
    else:
        ax = g_ax

    if muQ is not None and idx_slices is not None:
        for slice_t in idx_slices:
            point = muQ[slice_t]
            ax.scatter(point[0], point[1], point[2], c='red', s=10)

    # A single 2-D trajectory is promoted to a batch of one; lists are left untouched.
    if not isinstance(x_hat, list) and len(x_hat.shape) == 2:
        x_hat = np.expand_dims(x_hat, axis=0)

    if title is None:
        title = 'Unimanual trajectory'
    rf.visualab.traj_plot(x_hat, legend=legend, title=title, mode='3d', ori=ori, g_ax=ax)

    if save:
        np.save(save_file_name, np.array(x_hat))

    # Only block on a window when this function owns the figure and is not running under tests.
    if g_ax is None and not for_test:
        plt.show()
def plot_3d_bi(x_hat_l, x_hat_r, muQ_l=None, muQ_r=None, idx_slices=None, ori=False, save=False, save_file_name=None,
               g_ax=None, title=None, legend_lst=None, for_test=False):
    """
    Plot a bimanual (left and right) trajectory pair in 3D, optionally with reference points, and
    optionally save both trajectories.

    :param x_hat_l: left trajectory data handed to ``rf.visualab.traj_plot``. A ``list`` is passed
        through unchanged; a non-list object whose ``shape`` has exactly two entries gets a leading
        axis added first.
    :param x_hat_r: right trajectory data. It gets a leading axis added exactly when ``x_hat_l``
        does (the decision is made from ``x_hat_l`` only).
    :param muQ_l: indexable collection of left reference points (presumably via-point targets --
        confirm with the caller). Only used when ``muQ_r`` and ``idx_slices`` are also given.
    :param muQ_r: indexable collection of right reference points; see ``muQ_l``.
    :param idx_slices: iterable of indices into ``muQ_l`` / ``muQ_r``; for each index the first three
        components are drawn as a red (left) and an orange (right) marker.
    :param ori: forwarded to ``rf.visualab.traj_plot`` as ``ori``.
    :param save: if True, both trajectories are written with ``np.save``.
    :param save_file_name: sequence whose first entry is the left target and whose second entry is
        the right target for ``np.save``; required when ``save`` is True.
    :param g_ax: existing axes to draw on. When None, a new 4x4 figure with a 3D subplot is created.
    :param title: plot title; defaults to ``'Bimanual trajectory'``.
    :param legend_lst: sequence ``(left_legend, right_legend)``; defaults to ``'left arm'`` and
        ``'right arm'``.
    :param for_test: if True, ``plt.show()`` is never called.
    :raises ValueError: if ``save`` is True and ``save_file_name`` is None or has fewer than two entries.
    :raises TypeError: if ``save`` is True and ``save_file_name`` is a single ``str``/``bytes``.
    """
    # Validate before any drawing: an ``assert`` would vanish under ``python -O``, a bare string
    # would be silently indexed character by character, and a too-short sequence would fail only
    # after the left file had already been written.
    if save:
        if save_file_name is None:
            raise ValueError("save_file_name must be provided when save=True")
        if isinstance(save_file_name, (str, bytes)):
            raise TypeError("save_file_name must be a sequence of two targets (left, right), not a single string")
        if len(save_file_name) < 2:
            raise ValueError("save_file_name must contain two targets: (left, right)")

    if g_ax is None:
        ax = plt.figure(figsize=(4, 4)).add_subplot(111, projection='3d', fc='white')
    else:
        ax = g_ax

    if muQ_l is not None and muQ_r is not None and idx_slices is not None:
        for slice_t in idx_slices:
            point_l = muQ_l[slice_t]
            point_r = muQ_r[slice_t]
            ax.scatter(point_l[0], point_l[1], point_l[2], c='red', s=10)
            ax.scatter(point_r[0], point_r[1], point_r[2], c='orange', s=10)

    # A single pair of 2-D trajectories is promoted to batches of one; the left input decides for both.
    if not isinstance(x_hat_l, list) and len(x_hat_l.shape) == 2:
        x_hat_l = np.expand_dims(x_hat_l, axis=0)
        x_hat_r = np.expand_dims(x_hat_r, axis=0)

    if title is None:
        title = 'Bimanual trajectory'
    if legend_lst is None:
        legend_l, legend_r = 'left arm', 'right arm'
    else:
        legend_l, legend_r = legend_lst[0], legend_lst[1]

    # The title is set once, by the first call; the second call only adds the right-arm curve.
    rf.visualab.traj_plot(x_hat_l, title=title, legend=legend_l, mode='3d', ori=ori, g_ax=ax)
    rf.visualab.traj_plot(x_hat_r, legend=legend_r, mode='3d', ori=ori, g_ax=ax)

    if save:
        np.save(save_file_name[0], np.array(x_hat_l))
        np.save(save_file_name[1], np.array(x_hat_r))

    # Only block on a window when this function owns the figure and is not running under tests.
    if g_ax is None and not for_test:
        plt.show()