import sys
sys.setrecursionlimit(100000)
import os
import numpy as np
import torch
np.set_printoptions(threshold=np.inf)
import trimesh
import mesh_to_sdf
import skimage
import rofunc as rf
from tqdm import tqdm
from rofunc.utils.robolab.rdf import utils
[docs]class RDF:
def __init__(self, args):
"""
Use Bernstein Polynomial to represent the SDF of the robot
"""
self.args = args
self.n_func = args.n_func
self.domain_min = args.domain_min
self.domain_max = args.domain_max
self.device = args.device
self.robot_asset_root = args.robot_asset_root
self.robot_model_path = os.path.join(self.robot_asset_root, args.robot_asset_name)
self.save_mesh_dict = args.save_mesh_dict
# Build the robot from the URDF/MJCF file
self.robot = rf.robolab.RobotModel(self.robot_model_path, solve_engine="pytorch_kinematics", device=self.device,
verbose=False)
assert os.path.exists(self.robot.mesh_dir), "Please organize the robot meshes in the 'meshes' folder!"
self.link_list = self.robot.get_link_list()
self.link_mesh_map = self.robot.get_link_mesh_map()
self.link_mesh_name_map = {k: os.path.basename(v).split(".")[0] for k, v in self.link_mesh_map.items()}
def _binomial_coefficient(self, n, k):
return torch.exp(torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1))
def _build_bernstein_t(self, t, use_derivative=False):
# t is normalized to [0,1]
t = torch.clamp(t, min=1e-4, max=1 - 1e-4)
n = self.n_func - 1
i = torch.arange(self.n_func, device=self.device)
comb = self._binomial_coefficient(torch.tensor(n, device=self.device), i)
phi = comb * (1 - t).unsqueeze(-1) ** (n - i) * t.unsqueeze(-1) ** i
if not use_derivative:
return phi.float(), None
else:
dphi = -comb * (n - i) * (1 - t).unsqueeze(-1) ** (n - i - 1) * t.unsqueeze(-1) ** i + comb * i * (
1 - t).unsqueeze(-1) ** (n - i) * t.unsqueeze(-1) ** (i - 1)
dphi = torch.clamp(dphi, min=-1e4, max=1e4)
return phi.float(), dphi.float()
def _build_basis_function_from_points(self, points, use_derivative=False):
N = len(points)
points = ((points - self.domain_min) / (self.domain_max - self.domain_min)).reshape(-1)
phi, d_phi = self._build_bernstein_t(points, use_derivative)
phi = phi.reshape(N, 3, self.n_func)
phi_x = phi[:, 0, :]
phi_y = phi[:, 1, :]
phi_z = phi[:, 2, :]
phi_xy = torch.einsum("ij,ik->ijk", phi_x, phi_y).view(-1, self.n_func ** 2)
phi_xyz = torch.einsum("ij,ik->ijk", phi_xy, phi_z).view(-1, self.n_func ** 3)
if not use_derivative:
return phi_xyz, None
else:
d_phi = d_phi.reshape(N, 3, self.n_func)
d_phi_x_1D = d_phi[:, 0, :]
d_phi_y_1D = d_phi[:, 1, :]
d_phi_z_1D = d_phi[:, 2, :]
d_phi_x = torch.einsum("ij,ik->ijk",
torch.einsum("ij,ik->ijk", d_phi_x_1D, phi_y).view(-1, self.n_func ** 2),
phi_z).view(-1, self.n_func ** 3)
d_phi_y = torch.einsum("ij,ik->ijk",
torch.einsum("ij,ik->ijk", phi_x, d_phi_y_1D).view(-1, self.n_func ** 2),
phi_z).view(-1, self.n_func ** 3)
d_phi_z = torch.einsum("ij,ik->ijk", phi_xy, d_phi_z_1D).view(-1, self.n_func ** 3)
d_phi_xyz = torch.cat((d_phi_x.unsqueeze(-1), d_phi_y.unsqueeze(-1), d_phi_z.unsqueeze(-1)), dim=-1)
return phi_xyz, d_phi_xyz
def _sample_sdf_points(self, mesh, mesh_name):
print(f'Sampling points for mesh {mesh_name}...')
center = mesh.bounding_box.centroid
scale = np.max(np.linalg.norm(mesh.vertices - center, axis=1))
# sample points near surface (as same as deepSDF)
near_points, near_sdf = mesh_to_sdf.sample_sdf_near_surface(mesh,
number_of_points=500000,
surface_point_method='scan',
sign_method='normal',
scan_count=100,
scan_resolution=400,
sample_point_count=10000000,
normal_sample_count=100,
min_size=0.0,
return_gradients=False)
# # sample points randomly within the bounding box [-1,1]
random_points = np.random.rand(500000, 3) * 2.0 - 1.0
random_sdf = mesh_to_sdf.mesh_to_sdf(mesh,
random_points,
surface_point_method='scan',
sign_method='normal',
bounding_radius=None,
scan_count=100,
scan_resolution=400,
sample_point_count=10000000,
normal_sample_count=100)
# save data
data = {
'near_points': near_points,
'near_sdf': near_sdf,
'random_points': random_points,
'random_sdf': random_sdf,
'center': center,
'scale': scale
}
save_path = os.path.join(self.robot_asset_root, 'rdf/sdf_points')
rf.oslab.create_dir(save_path)
np.save(os.path.join(save_path, f'voxel_128_{mesh_name}.npy'), data)
print(f'Sampling points for mesh {mesh_name} finished!')
return data
[docs] def train(self):
mesh_files = rf.oslab.list_absl_path(self.robot.mesh_dir, recursive=True, suffix='.stl')
mesh_files2 = rf.oslab.list_absl_path(self.robot.mesh_dir, recursive=True, suffix='.STL')
mesh_files = mesh_files + mesh_files2
mesh_dict = {}
# sample points for each mesh
if self.args.sampled_points:
save_path = os.path.join(self.robot_asset_root, 'rdf/sdf_points')
rf.oslab.create_dir(save_path)
if self.args.parallel:
import multiprocessing
pool = multiprocessing.Pool(processes=12)
data_list = pool.map(job, [(mf, mf.split('/')[-1].split('.')[0], save_path) for i, mf in
enumerate(mesh_files) if not os.path.exists(
os.path.join(save_path, f'voxel_128_{mf.split("/")[-1].split(".")[0]}.npy'))])
for data in data_list:
mesh_name = data['mesh_name']
np.save(os.path.join(save_path, f'voxel_128_{mesh_name}.npy'), data)
else:
for i, mf in enumerate(tqdm(mesh_files)):
mesh_name = mf.split('/')[-1].split('.')[0]
if os.path.exists(os.path.join(save_path, f'voxel_128_{mesh_name}.npy')):
continue
data = sample_sdf_points(mf, mesh_name, save_path)
np.save(os.path.join(save_path, f'voxel_128_{mesh_name}.npy'), data)
def train_single_mesh(mf, i, data):
mesh_name = mf.split('/')[-1].split('.')[0]
print(f'Mesh {mesh_name} start processing...')
mesh = trimesh.load(mf)
offset = mesh.bounding_box.centroid
scale = np.max(np.linalg.norm(mesh.vertices - offset, axis=1))
point_near_data = data['near_points']
sdf_near_data = data['near_sdf']
point_random_data = data['random_points']
sdf_random_data = data['random_sdf']
sdf_random_data[sdf_random_data < -1] = -sdf_random_data[sdf_random_data < -1]
wb = torch.zeros(self.n_func ** 3).float().to(self.device)
batch_size = (torch.eye(self.n_func ** 3) / 1e-4).float().to(self.device)
# loss_list = []
for iter in range(self.args.train_epoch):
choice_near = np.random.choice(len(point_near_data), 1024, replace=False)
p_near, sdf_near = torch.from_numpy(point_near_data[choice_near]).float().to(
self.device), torch.from_numpy(sdf_near_data[choice_near]).float().to(self.device)
choice_random = np.random.choice(len(point_random_data), 256, replace=False)
p_random, sdf_random = torch.from_numpy(point_random_data[choice_random]).float().to(
self.device), torch.from_numpy(sdf_random_data[choice_random]).float().to(self.device)
p = torch.cat([p_near, p_random], dim=0)
sdf = torch.cat([sdf_near, sdf_random], dim=0)
phi_xyz, _ = self._build_basis_function_from_points(p.float().to(self.device),
use_derivative=False)
K = torch.matmul(batch_size, phi_xyz.T).matmul(torch.linalg.inv(
(torch.eye(len(p)).float().to(self.device) + torch.matmul(torch.matmul(phi_xyz, batch_size),
phi_xyz.T))))
batch_size -= torch.matmul(K, phi_xyz).matmul(batch_size)
delta_wb = torch.matmul(K, (sdf - torch.matmul(phi_xyz, wb)).squeeze())
# loss = torch.nn.functional.mse_loss(torch.matmul(phi_xyz,wb).squeeze(), sdf, reduction='mean').item()
# loss_list.append(loss)
wb += delta_wb
print(f'Mesh {mesh_name} finished!')
mesh_dict_single = {
'i': i,
'mesh_name': mesh_name,
'weights': wb,
'offset': torch.from_numpy(offset),
'scale': scale,
}
return mesh_dict_single
# train the model for each mesh
for i, mf in enumerate(tqdm(mesh_files)):
mesh_name = mf.split('/')[-1].split('.')[0]
sampled_point_data = np.load(f'{self.robot_asset_root}/rdf/sdf_points/voxel_128_{mesh_name}.npy',
allow_pickle=True).item()
res = train_single_mesh(mf, i, sampled_point_data)
mesh_dict[res['mesh_name']] = res
self.mesh_dict = mesh_dict
if self.save_mesh_dict:
rdf_model_path = os.path.join(self.robot_asset_root, 'rdf', 'BP')
rf.oslab.create_dir(rdf_model_path)
torch.save(mesh_dict, f'{rdf_model_path}/BP_{self.n_func}.pt') # save the robot sdf model
print(f'{rdf_model_path}/BP_{self.n_func}.pt model saved!')
[docs] def sdf_to_mesh(self, model, nbData, use_derivative=False):
verts_list, faces_list, mesh_name_list = [], [], []
for i, k in enumerate(model.keys()):
mesh_dict = model[k]
mesh_name = mesh_dict['mesh_name']
print(f'{mesh_name}')
mesh_name_list.append(mesh_name)
weights = mesh_dict['weights'].to(self.device)
domain = torch.linspace(self.domain_min, self.domain_max, nbData).to(self.device)
grid_x, grid_y, grid_z = torch.meshgrid(domain, domain, domain)
grid_x, grid_y, grid_z = grid_x.reshape(-1, 1), grid_y.reshape(-1, 1), grid_z.reshape(-1, 1)
p = torch.cat([grid_x, grid_y, grid_z], dim=1).float().to(self.device)
# split data to deal with memory issues
p_split = torch.split(p, 10000, dim=0)
d = []
for p_s in p_split:
phi_p, d_phi_p = self._build_basis_function_from_points(p_s, use_derivative)
d_s = torch.matmul(phi_p, weights)
d.append(d_s)
d = torch.cat(d, dim=0)
verts, faces, normals, values = skimage.measure.marching_cubes(
d.view(nbData, nbData, nbData).detach().cpu().numpy(), level=0.0,
spacing=np.array([(self.domain_max - self.domain_min) / nbData] * 3)
)
verts = verts - [1, 1, 1]
verts_list.append(verts)
faces_list.append(faces)
return verts_list, faces_list, mesh_name_list
[docs] def create_surface_mesh(self, model, nbData, vis=False, save_mesh_name=None):
verts_list, faces_list, mesh_name_list = self.sdf_to_mesh(model, nbData)
for verts, faces, mesh_name in zip(verts_list, faces_list, mesh_name_list):
rec_mesh = trimesh.Trimesh(verts, faces)
if vis:
rec_mesh.show()
if save_mesh_name is not None:
save_path = os.path.join(self.robot_asset_root, 'rdf', "output_meshes")
rf.oslab.create_dir(save_path)
trimesh.exchange.export.export_mesh(rec_mesh,
os.path.join(save_path, f"{save_mesh_name}_{mesh_name}.stl"))
[docs] def get_whole_body_sdf_batch(self, points, joint_value, model, base_trans=None, use_derivative=True,
used_links=None):
batch_size = len(joint_value)
N = len(points)
if used_links is None:
used_links = self.robot.get_link_list()
used_links = [link for link in used_links if link in self.link_mesh_name_map]
K = len(used_links)
offset = torch.cat([model[self.link_mesh_name_map[i]]['offset'].unsqueeze(0) for i in used_links if
i in self.link_mesh_name_map], dim=0).to(self.device)
offset = offset.unsqueeze(0).expand(batch_size, K, 3).reshape(batch_size * K, 3).float()
scale = torch.tensor([model[self.link_mesh_name_map[i]]['scale'] for i in used_links if
i in self.link_mesh_name_map], device=self.device)
scale = scale.unsqueeze(0).expand(batch_size, K).reshape(batch_size * K).float()
trans_dict = self.robot.get_trans_dict(joint_value, base_trans)
trans_lists = np.array([val.detach().cpu().numpy() for key, val in trans_dict.items() if key in used_links])
trans_lists = torch.tensor(trans_lists).reshape((K, batch_size, 4, 4)).to(self.device)
fk_trans = torch.cat([t.unsqueeze(1) for t in trans_lists], dim=1)[:, :, :, :].reshape(-1, 4,
4) # batch_size,K,4,4
x_robot_frame_batch = utils.transform_points(points.float(), torch.linalg.inv(fk_trans).float(),
device=self.device) # batch_size*K,N,3
x_robot_frame_batch_scaled = x_robot_frame_batch - offset.unsqueeze(1)
x_robot_frame_batch_scaled = x_robot_frame_batch_scaled / scale.unsqueeze(-1).unsqueeze(-1) # batch_size*K,N,3
x_bounded = torch.where(x_robot_frame_batch_scaled > 1.0 - 1e-2, 1.0 - 1e-2, x_robot_frame_batch_scaled)
x_bounded = torch.where(x_bounded < -1.0 + 1e-2, -1.0 + 1e-2, x_bounded)
res_x = x_robot_frame_batch_scaled - x_bounded
if not use_derivative:
phi, _ = self._build_basis_function_from_points(x_bounded.reshape(batch_size * K * N, 3),
use_derivative=False)
phi = phi.reshape(batch_size, K, N, -1).transpose(0, 1).reshape(K, batch_size * N, -1) # K,batch_size*N,-1
weights_near = torch.cat([model[self.link_mesh_name_map[i]]['weights'].unsqueeze(0) for i in used_links if
i in self.link_mesh_name_map], dim=0).to(self.device)
# sdf
sdf = torch.einsum('ijk,ik->ij', phi, weights_near).reshape(K, batch_size, N).transpose(0, 1).reshape(
batch_size * K,
N) # batch_size,K,N
sdf = sdf + res_x.norm(dim=-1)
sdf = sdf.reshape(batch_size, K, N)
sdf = sdf * scale.reshape(batch_size, K).unsqueeze(-1)
sdf_value, idx = sdf.min(dim=1)
return sdf_value, None
else:
phi, dphi = self._build_basis_function_from_points(x_bounded.reshape(batch_size * K * N, 3),
use_derivative=True)
phi_cat = torch.cat([phi.unsqueeze(-1), dphi], dim=-1)
phi_cat = phi_cat.reshape(batch_size, K, N, -1, 4).transpose(0, 1).reshape(K, batch_size * N, -1,
4) # K,batch_size*N,-1,4
weights_near = torch.cat([model[self.link_mesh_name_map[i]]['weights'].unsqueeze(0) for i in used_links],
dim=0).to(self.device)
output = torch.einsum('ijkl,ik->ijl', phi_cat, weights_near).reshape(K, batch_size, N, 4).transpose(0,
1).reshape(
batch_size * K, N, 4)
sdf = output[:, :, 0]
gradient = output[:, :, 1:]
# sdf
sdf = sdf + res_x.norm(dim=-1)
sdf = sdf.reshape(batch_size, K, N)
sdf = sdf * (scale.reshape(batch_size, K).unsqueeze(-1))
sdf_value, idx = sdf.min(dim=1)
# derivative
gradient = res_x + torch.nn.functional.normalize(gradient, dim=-1)
gradient = torch.nn.functional.normalize(gradient, dim=-1).float()
# gradient = gradient.reshape(batch_size,K,N,3)
fk_rotation = fk_trans[:, :3, :3]
gradient_base_frame = torch.einsum('ijk,ikl->ijl', fk_rotation, gradient.transpose(1, 2)).transpose(1,
2).reshape(
batch_size, K, N, 3)
# norm_gradient_base_frame = torch.linalg.norm(gradient_base_frame,dim=-1)
# exit()
# print(norm_gradient_base_frame)
idx = idx.unsqueeze(1).unsqueeze(-1).expand(batch_size, K, N, 3)
gradient_value = torch.gather(gradient_base_frame, 1, idx)[:, 0, :, :]
# gradient_value = None
return sdf_value, gradient_value
[docs] def get_whole_body_sdf_with_joints_grad_batch(self, points, joint_value, model, base_trans=None, used_links=None):
"""
Get the SDF value and gradient of the whole body with respect to the joints
:param points: (batch_size, 3)
:param joint_value: (batch_size, joint_num)
:param model: the trained RDF model
:param base_trans: the transformation matrix of base pose, (1, 4, 4)
:param used_links: the links to be used, list of link names
:return:
"""
delta = 0.001
batch_size = joint_value.shape[0]
joint_num = joint_value.shape[1]
link_num = len(self.robot.get_link_list())
joint_value = joint_value.unsqueeze(1)
d_joint_value = (joint_value.expand(batch_size, joint_num, joint_num) + torch.eye(joint_num,
device=self.device).unsqueeze(
0).expand(batch_size, joint_num, joint_num) * delta).reshape(batch_size, -1, joint_num)
joint_value = torch.cat([joint_value, d_joint_value], dim=1).reshape(batch_size * (joint_num + 1), joint_num)
if base_trans is not None:
base_trans = base_trans.unsqueeze(1).expand(batch_size, (joint_num + 1), 4, 4).reshape(
batch_size * (joint_num + 1), 4, 4)
sdf, _ = self.get_whole_body_sdf_batch(points, joint_value, model, base_trans=base_trans, use_derivative=False,
used_links=used_links)
sdf = sdf.reshape(batch_size, (joint_num + 1), -1)
d_sdf = (sdf[:, 1:, :] - sdf[:, :1, :]) / delta
return sdf[:, 0, :], d_sdf.transpose(1, 2)
[docs] def get_whole_body_normal_with_joints_grad_batch(self, points, joint_value, model, base_trans=None,
used_links=None):
"""
Get the normal vector of the whole body with respect to the joints
:param points: (batch_size, 3)
:param joint_value: (batch_size, joint_num)
:param model: the trained RDF model
:param base_trans: the transformation matrix of base pose, (1, 4, 4)
:param used_links: the links to be used, list of link names
:return:
"""
delta = 0.001
batch_size = joint_value.shape[0]
joint_num = joint_value.shape[1]
link_num = len(self.robot.get_link_list())
joint_value = joint_value.unsqueeze(1)
d_joint_value = (joint_value.expand(batch_size, joint_num, joint_num) +
torch.eye(joint_num, device=self.device).unsqueeze(0).expand(batch_size, joint_num,
joint_num) * delta).reshape(
batch_size, -1, joint_num)
joint_value = torch.cat([joint_value, d_joint_value], dim=1).reshape(batch_size * (joint_num + 1), joint_num)
if base_trans is not None:
base_trans = base_trans.unsqueeze(1).expand(batch_size, (joint_num + 1), 4, 4).reshape(
batch_size * (joint_num + 1), 4, 4)
sdf, normal = self.get_whole_body_sdf_batch(points, joint_value, model, base_trans=base_trans,
use_derivative=True, used_links=used_links)
normal = normal.reshape(batch_size, (joint_num + 1), -1, 3).transpose(1, 2)
return normal # normal size: (batch_size,N,8,3) normal[:,:,0,:] origin normal vector normal[:,:,1:,:] derivatives with respect to joints
[docs] def visualize_reconstructed_whole_body(self, model, trans_list, tag):
"""
Visualize the reconstructed whole body
:param model: the trained RDF model
:param trans_list: the transformation matrices of all links
:param tag: the tag of the mesh, e.g., 'BP_8'
:return:
"""
view_mat = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
scene = trimesh.Scene()
for link_name, origin_mf in self.link_mesh_map.items():
if origin_mf is not None:
mesh_name = origin_mf.split('/')[-1].split('.')[0]
mf = os.path.join(self.robot_asset_root, f"rdf/output_meshes/{tag}_{mesh_name}.stl")
mesh = trimesh.load(mf)
mesh_dict = model[mesh_name]
offset = mesh_dict['offset'].cpu().numpy()
scale = mesh_dict['scale']
mesh.vertices = mesh.vertices * scale + offset
all_related_link = [key for key in trans_list.keys() if link_name in key]
related_link = all_related_link[-1]
mesh.apply_transform(trans_list[related_link].squeeze().cpu().numpy())
mesh.apply_transform(view_mat)
scene.add_geometry(mesh)
scene.show()
[docs]def job(args):
return sample_sdf_points(args[0], args[1], args[2])
[docs]def sample_sdf_points(mf, mesh_name, save_path):
print(f'Sampling points for mesh {mesh_name}...')
mesh = trimesh.load(mf)
mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
center = mesh.bounding_box.centroid
scale = np.max(np.linalg.norm(mesh.vertices - center, axis=1))
# sample points near surface (as same as deepSDF)
near_points, near_sdf = mesh_to_sdf.sample_sdf_near_surface(mesh,
number_of_points=500000,
surface_point_method='scan',
sign_method='normal',
scan_count=100,
scan_resolution=400,
sample_point_count=10000000,
normal_sample_count=100,
min_size=0.0,
return_gradients=False)
# # sample points randomly within the bounding box [-1,1]
random_points = np.random.rand(500000, 3) * 2.0 - 1.0
random_sdf = mesh_to_sdf.mesh_to_sdf(mesh,
random_points,
surface_point_method='scan',
sign_method='normal',
bounding_radius=None,
scan_count=100,
scan_resolution=400,
sample_point_count=10000000,
normal_sample_count=100)
# save data
data = {
'mesh_name': mesh_name,
'near_points': near_points,
'near_sdf': near_sdf,
'random_points': random_points,
'random_sdf': random_sdf,
'center': center,
'scale': scale
}
np.save(os.path.join(save_path, f'voxel_128_{mesh_name}.npy'), data)
print(f'Sampling points for mesh {mesh_name} finished!')
return data