# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import artist
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.mplot3d.art3d import Line3D, Text3D
import rofunc as rf
from rofunc.utils.robolab.coord.transform import check_rot_matrix
def set_axis(ax, data=None, labels=None, elev=45, azim=45, roll=0):
    """
    Set the view angle, box aspect, labels and (optionally) limits of a 3D axis.

    Example::

        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import rofunc as rf
        >>> fig = plt.figure()
        >>> ax = fig.add_subplot(111, projection='3d')
        >>> data = [np.array([0, 1, 2]), np.array([0, 1, 2]), np.array([0, 1, 2])]
        >>> rf.visualab.set_axis(ax, data)
        >>> plt.show()

    :param ax: the 3D axis to configure
    :param data: optional sequence ``[X, Y, Z]`` of array-likes (lists or numpy arrays) used to
        choose the axis limits. All three axes get the same half-width, equal to the largest
        per-coordinate range divided by 1.2, centred on each coordinate's midpoint.
    :param labels: the three axis labels; defaults to ``['x', 'y', 'z']``
    :param elev: elevation angle passed to ``ax.view_init``
    :param azim: azimuth angle passed to ``ax.view_init``
    :param roll: roll angle passed to ``ax.view_init``; dropped when the installed
        ``view_init`` does not accept a ``roll`` keyword
    """
    try:
        ax.view_init(elev=elev, azim=azim, roll=roll)
    except TypeError:
        # Older matplotlib releases reject the unexpected ``roll`` keyword with a
        # TypeError. Any other failure is a genuine error and is allowed to propagate.
        ax.view_init(elev=elev, azim=azim)
    ax.set_box_aspect([1, 1, 1])

    if labels is None:
        labels = ['x', 'y', 'z']
    ax.set_xlabel(labels[0])
    ax.set_ylabel(labels[1])
    ax.set_zlabel(labels[2])

    if data is not None:
        # Accept plain lists as well as numpy arrays; ``.max()``/``.min()`` need ndarrays.
        X, Y, Z = (np.asarray(coord) for coord in data)
        # A single shared half-width keeps the cubic box aspect undistorted.
        max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 1.2
        mid_x = (X.max() + X.min()) * 0.5
        mid_y = (Y.max() + Y.min()) * 0.5
        mid_z = (Z.max() + Z.min()) * 0.5
        ax.set_xlim(mid_x - max_range, mid_x + max_range)
        ax.set_ylim(mid_y - max_range, mid_y + max_range)
        ax.set_zlim(mid_z - max_range, mid_z + max_range)
def save_img(fig, save_dir, fig_name=None, dpi=300, transparent=False, format=None):
    """
    Save a figure to ``save_dir`` in one or more file formats.

    :param fig: the figure to save; its ``savefig`` method is called once per format
    :param save_dir: target directory, created via ``rf.oslab.create_dir`` before saving
    :param fig_name: base file name without extension; if None, ``fig_{n}`` is used, where
        ``n`` is the number of entries already present in ``save_dir``
    :param dpi: resolution passed to ``fig.savefig``
    :param transparent: whether to save with a transparent background
    :param format: a single format string (e.g. ``'png'``) or an iterable of format strings;
        defaults to ``['eps', 'png']``. One file ``{fig_name}.{format}`` is written per format.
    :return: None
    """
    if format is None:
        format = ['eps', 'png']
    elif isinstance(format, str):
        # A bare string would otherwise be iterated character by character ('p', 'n', 'g').
        format = [format]

    rf.oslab.create_dir(save_dir)
    if fig_name is None:
        fig_name = 'fig_{}'.format(len(os.listdir(save_dir)))

    for ext in format:
        save_path = os.path.join(save_dir, '{}.{}'.format(fig_name, ext))
        fig.savefig(save_path, dpi=dpi, transparent=transparent, format=ext)
class Frame(artist.Artist):
    """A Matplotlib artist that displays a frame represented by its basis.

    The x, y and z basis vectors are drawn as red, green and blue line
    segments starting at the frame origin.

    Parameters
    ----------
    A2B : array-like, shape (4, 4)
        Transform from frame A to frame B. Only the upper-left 3x3 block
        (rotation) and the first three entries of the last column
        (translation) are read.
    label : str, optional (default: None)
        Name of the frame. When None, no label is drawn.
    s : float, optional (default: 1)
        Length of basis vectors
    draw_label_indicator : bool, optional (default: True)
        Controls whether the line from the frame origin to frame label is
        drawn.

    Other arguments except 'c' and 'color' are passed on to Line3D.
    """

    def __init__(self, A2B, label=None, s=1.0, **kwargs):
        super(Frame, self).__init__()

        # The basis colours are fixed (r/g/b/k below), so user-supplied colours
        # are discarded instead of clashing with the explicit ``color=`` argument.
        kwargs.pop("c", None)
        kwargs.pop("color", None)
        self.draw_label_indicator = kwargs.pop("draw_label_indicator", True)

        self.s = s

        self.x_axis = Line3D([], [], [], color="r", **kwargs)
        self.y_axis = Line3D([], [], [], color="g", **kwargs)
        self.z_axis = Line3D([], [], [], color="b", **kwargs)

        self.draw_label = label is not None
        self.label = label

        if self.draw_label:
            if self.draw_label_indicator:
                self.label_indicator = Line3D([], [], [], color="k", **kwargs)
            self.label_text = Text3D(0, 0, 0, text="", zdir="x")

        self.set_data(A2B, label)

    def set_data(self, A2B, label=None):
        """Set the transformation data.

        Parameters
        ----------
        A2B : array-like, shape (4, 4)
            Transform from frame A to frame B
        label : str, optional (default: None)
            Name of the frame. Falls back to the label given at construction
            when None. Ignored if the frame was created without a label.
        """
        # Convert so nested lists are accepted; the 2D slicing below needs an ndarray.
        A2B = np.asarray(A2B)
        R = A2B[:3, :3]
        p = A2B[:3, 3]

        # Column d of R is the d-th basis vector; draw it from p to p + s * R[:, d].
        for d, b in enumerate([self.x_axis, self.y_axis, self.z_axis]):
            b.set_data(np.array([p[0], p[0] + self.s * R[0, d]]),
                       np.array([p[1], p[1] + self.s * R[1, d]]))
            b.set_3d_properties(np.array([p[2], p[2] + self.s * R[2, d]]))

        if self.draw_label:
            if label is None:
                label = self.label
            # Place the label halfway along the sum of the scaled basis vectors.
            label_pos = p + 0.5 * self.s * (R[:, 0] + R[:, 1] + R[:, 2])

            if self.draw_label_indicator:
                self.label_indicator.set_data(
                    np.array([p[0], label_pos[0]]),
                    np.array([p[1], label_pos[1]]))
                self.label_indicator.set_3d_properties(
                    np.array([p[2], label_pos[2]]))

            self.label_text.set_text(label)
            self.label_text.set_position([label_pos[0], label_pos[1]])
            self.label_text.set_3d_properties(label_pos[2], zdir="x")

    @artist.allow_rasterization
    def draw(self, renderer, *args, **kwargs):
        """Draw the basis lines and, if enabled, the label and its indicator."""
        for b in [self.x_axis, self.y_axis, self.z_axis]:
            b.draw(renderer, *args, **kwargs)
        if self.draw_label:
            if self.draw_label_indicator:
                self.label_indicator.draw(renderer, *args, **kwargs)
            self.label_text.draw(renderer, *args, **kwargs)
        super(Frame, self).draw(renderer, *args, **kwargs)

    def add_frame(self, axis):
        """Add the frame's lines (and label, if any) to a 3D axis."""
        for b in [self.x_axis, self.y_axis, self.z_axis]:
            axis.add_line(b)
        if self.draw_label:
            if self.draw_label_indicator:
                axis.add_line(self.label_indicator)
            # NOTE(review): ``_add_text`` is a private matplotlib Axes method and
            # may change between matplotlib versions — confirm on upgrade.
            axis._add_text(self.label_text)
def make_3d_axis(ax_s, pos=111, unit=None, n_ticks=5):
    """
    Create a new 3D subplot suited to drawing coordinate frames.

    :param ax_s: half-width of the cubic plotting volume; each axis spans ``[-ax_s, ax_s]``
    :param pos: subplot position passed to ``plt.subplot`` (nrows, ncols, plot_number), e.g. 111
    :param unit: optional unit appended to each axis label, e.g. ``'m'`` yields ``'X [m]'``;
        with None the labels are plain ``'X'``, ``'Y'``, ``'Z'``
    :param n_ticks: maximum number of major ticks per axis, applied through ``MaxNLocator``
    :return: the newly created 3D axis
    """
    try:
        ax = plt.subplot(pos, projection="3d", aspect="equal")
    except NotImplementedError:
        # HACK: some matplotlib versions (ca. 3.0.2) refuse a manual aspect on 3D axes
        # ("It is not currently possible to manually set the aspect"), so retry without it.
        ax = plt.subplot(pos, projection="3d")

    suffix = "" if unit is None else " [%s]" % unit
    limits = (-ax_s, ax_s)
    plt.setp(
        ax, xlim=limits, ylim=limits, zlim=limits,
        xlabel="X" + suffix, ylabel="Y" + suffix, zlabel="Z" + suffix)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_major_locator(MaxNLocator(n_ticks))

    try:
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_color("white")
    except AttributeError:  # pragma: no cover
        # Older matplotlib exposes the panes through the ``w_*axis`` aliases
        # (deprecated since v3.1).
        for axis in (ax.w_xaxis, ax.w_yaxis, ax.w_zaxis):
            axis.pane.set_color("white")
    return ax
def plot_basis(ax=None, R=None, p=None, s=1.0, ax_s=1,
               strict_check=True, **kwargs):
    """
    Plot the basis vectors of a rotation matrix as a coordinate frame.

    :param ax: the 3D axis to draw on; if None, a new one is created with ``make_3d_axis(ax_s)``
    :param R: rotation matrix whose columns are the basis vectors; defaults to the identity
    :param p: offset of the frame from the origin; defaults to ``[0, 0, 0]``
    :param s: scaling (length) of the drawn basis vectors
    :param ax_s: scaling of the new matplotlib 3d axis (only used when ``ax`` is None)
    :param strict_check: passed to ``check_rot_matrix``. If True, raise a ValueError when the
        rotation matrix is not numerically close enough to a real rotation matrix; otherwise a
        warning is printed.
    :param kwargs: additional arguments forwarded to ``Frame`` (e.g. ``label``, ``alpha``)
    :return: ax
    """
    if ax is None:
        ax = make_3d_axis(ax_s)

    if R is None:
        R = np.eye(3)
    if p is None:
        # Build a fresh array per call rather than sharing one mutable default array.
        p = np.zeros(3)

    R = check_rot_matrix(R, strict_check=strict_check)

    # Assemble the homogeneous transform expected by Frame.
    A2B = np.eye(4)
    A2B[:3, :3] = R
    A2B[:3, 3] = p

    frame = Frame(A2B, s=s, **kwargs)
    frame.add_frame(ax)
    return ax