Source code for rofunc.learning.ml.visualize

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import rofunc as rf

color_list = ['steelblue', 'orangered', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan']


[docs]def hmm_plot(nb_dim, demos_xdx_f, model, task_params=None): if nb_dim == 2: fig = hmm_plot2d(demos_xdx_f, model, task_params) elif nb_dim > 2: fig = hmm_plot_3d(demos_xdx_f, model, scale=0.1, task_params=task_params) else: raise Exception('Dimension is less than 2, cannot plot') return fig
[docs]def poe_plot(nb_dim, mod_list, prod, demos_x, show_demo_idx, task_params=None): if nb_dim == 2: fig = poe_plot2d(mod_list, prod, demos_x, show_demo_idx, task_params) elif nb_dim > 2: fig = poe_plot_3d(mod_list, prod, demos_x, show_demo_idx, task_params) else: raise Exception('Dimension is less than 2, cannot plot') return fig
[docs]def gen_plot(nb_dim, xi, prod, demos_x, show_demo_idx, title='Trajectory reproduction', label='reproduced line'): if nb_dim == 2: fig = generate_plot_2d(xi, prod, demos_x, show_demo_idx, title=title, label=label) elif nb_dim > 2: fig = generate_plot_3d(xi, prod, demos_x, show_demo_idx, title=title, label=label) else: raise Exception('Dimension is less than 2, cannot plot') return fig
[docs]def hmm_plot2d(demos_xdx_f, model, task_params=None): P = len(demos_xdx_f[0][0]) fig, ax = plt.subplots(ncols=P, nrows=1) fig.set_size_inches(4 * P, 6) for p in range(P): if 'frame_names' in task_params: ax[p].set_title('{} frame'.format(task_params['frame_names'][p])) else: ax[p].set_title('Frame {}'.format(p + 1)) for d in demos_xdx_f: ax[p].plot(d[:, p, 0], d[:, p, 1]) rf.visualab.gmm_plot(model.mu, model.sigma, ax=ax[p], dim=[4 * p, 4 * p + 1], color=color_list[p], alpha=0.1) plt.tight_layout() return fig
[docs]def hmm_plot_3d(demos_xdx_f, model, scale=1, task_params=None): P = len(demos_xdx_f[0][0]) nb_dim_deriv = len(demos_xdx_f[0][0][0]) fig = plt.figure(figsize=(4, 4)) fig.set_size_inches((4 * P, 6)) for p in range(P): ax = fig.add_subplot(1, P, p + 1, projection='3d', fc='white') if 'frame_names' in task_params: ax.set_title('{} frame'.format(task_params['frame_names'][p])) else: ax.set_title('Frame {}'.format(p + 1)) for d in demos_xdx_f: ax.plot(d[:, p, 0], d[:, p, 1], d[:, p, 2]) rf.visualab.gmm_plot(model.mu, model.sigma, ax=ax, dim=[nb_dim_deriv * p, nb_dim_deriv * p + 1, nb_dim_deriv * p + 2], color=color_list[p], scale=scale, alpha=0.1) rf.visualab.set_axis(ax, data=[d[:, p, 0], d[:, p, 1], d[:, p, 2]]) return fig
[docs]def poe_plot2d(mod_list, prod, demos_x, demo_idx, task_params): P = len(mod_list) fig, ax = plt.subplots(ncols=P + 1, nrows=1) fig.set_size_inches((4 * (P + 1), 6)) for i in ax: i.set_aspect('equal') for p in range(P): if task_params is not None: ax[p].set_title('{} frame'.format(task_params['frame_names'][p])) else: ax[p].set_title('Model (frame %d)' % (p + 1)) rf.visualab.gmm_plot(mod_list[p].mu, mod_list[p].sigma, swap=True, ax=ax[p], dim=[0, 1], color=color_list[p], alpha=0.3) ax[P].set_title('Product of Experts (PoE)') for p in range(P): rf.visualab.gmm_plot(mod_list[p].mu, mod_list[p].sigma, swap=True, ax=ax[P], dim=[0, 1], color=color_list[p], alpha=0.3) rf.visualab.gmm_plot(prod.mu, prod.sigma, swap=True, ax=ax[P], dim=[0, 1], color='gold', alpha=0.3) ax[P].plot(demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], color="b") patches = [mpatches.Patch(color=color_list[p], label='Model (frame %d)' % (p + 1)) for p in range(P)] patches.append(mpatches.Patch(color='gold', label='Product')) plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) return fig
[docs]def poe_plot_3d(mod_list, prod, demos_x, demo_idx, task_params=None): P = len(mod_list) fig = plt.figure(figsize=(4, 4)) fig.set_size_inches((4 * (P + 1), 6)) for p in range(P): ax = fig.add_subplot(1, P + 1, p + 1, projection='3d', fc='white') if task_params is not None: ax.set_title('{} frame'.format(task_params['frame_names'][p])) else: ax.set_title('Model (frame %d)' % (p + 1)) rf.visualab.gmm_plot(mod_list[p].mu, mod_list[p].sigma, swap=True, ax=ax, dim=[0, 1, 2], color=color_list[p], alpha=0.05) rf.visualab.set_axis(ax, data=[demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], demos_x[demo_idx][:, 2]]) ax = fig.add_subplot(1, P + 1, P + 1, projection='3d', fc='white') ax.set_title('Product of Experts (PoE)') for p in range(P): rf.visualab.gmm_plot(mod_list[p].mu, mod_list[p].sigma, swap=True, ax=ax, dim=[0, 1, 2], color=color_list[p], alpha=0.05) rf.visualab.gmm_plot(prod.mu, prod.sigma, swap=True, ax=ax, dim=[0, 1, 2], color='gold', alpha=0.05) ax.plot(demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], demos_x[demo_idx][:, 2], color="b") rf.visualab.set_axis(ax, data=[demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], demos_x[demo_idx][:, 2]]) patches = [mpatches.Patch(color=color_list[p], label='Model (frame %d)' % (p + 1)) for p in range(P)] patches.append(mpatches.Patch(color='gold', label='Product')) plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) return fig
[docs]def generate_plot_2d(xi, prod, demos_x, demo_idx, title='Trajectory reproduction', label='reproduced line'): fig = plt.figure() plt.title(title) rf.visualab.gmm_plot(prod.mu, prod.sigma, swap=True, dim=[0, 1], color='gold', alpha=0.5) plt.plot(demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], 'k--', lw=2, label='demo line') plt.plot(xi[:, 0], xi[:, 1], color='r', lw=2, label=label) plt.axis('equal') plt.legend() return fig
[docs]def generate_plot_3d(xi, prod, demos_x, demo_idx, scale=0.01, plot_gmm=False, plot_ori=True, title='Trajectory reproduction', label='reproduced line'): fig = plt.figure(figsize=(4, 4)) ax = fig.add_subplot(111, projection='3d', fc='white') ax.set_title(title) if plot_gmm: rf.visualab.gmm_plot(prod.mu, prod.sigma, dim=[0, 1, 2], color='gold', scale=0.01, ax=ax) ax.plot(demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], demos_x[demo_idx][:, 2], 'k--', lw=2, label='demo line') ax.plot(xi[:, 0], xi[:, 1], xi[:, 2], color='r', lw=2, label=label) rf.visualab.set_axis(ax, data=[demos_x[demo_idx][:, 0], demos_x[demo_idx][:, 1], demos_x[demo_idx][:, 2]]) plt.legend() if plot_ori and xi.shape[1] == 7: t = np.arange(len(xi)) plt.figure() plt.subplot(2, 2, 1) plt.plot(np.arange(len(demos_x[demo_idx][:, 3])), demos_x[demo_idx][:, 3], 'k--', lw=2, label='demo line') plt.plot(t, xi[:, 3], color='r', lw=2, label=label) plt.title('w-t') plt.subplot(2, 2, 2) plt.plot(np.arange(len(demos_x[demo_idx][:, 4])), demos_x[demo_idx][:, 4], 'k--', lw=2, label='demo line') plt.plot(t, xi[:, 4], color='r', lw=2, label=label) plt.title('x-t') plt.subplot(2, 2, 3) plt.plot(np.arange(len(demos_x[demo_idx][:, 5])), demos_x[demo_idx][:, 5], 'k--', lw=2, label='demo line') plt.plot(t, xi[:, 5], color='r', lw=2, label=label) plt.title('y-t') plt.subplot(2, 2, 4) plt.plot(np.arange(len(demos_x[demo_idx][:, 6])), demos_x[demo_idx][:, 6], 'k--', lw=2, label='demo line') plt.plot(t, xi[:, 6], color='r', lw=2, label=label) plt.title('z-t') return fig