"""Dataset directory lookup and batch preparation helpers."""
import sys
import os.path as osp
from collections import defaultdict

import numpy as np
import torch

# Make the parent directory of this file importable.
ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Maps a dataset name to the directory (relative to the data root) holding it.
DATA_DIRS = {
    'faust': 'FAUST_r',
    'faust_ori': 'FAUST_r_ori',
    'scape': 'SCAPE_r',
    'scape_ori': 'SCAPE_r_ori',
    'smalr': 'SMAL_r',
    'smalr_ori': 'SMAL_r_ori',
    'shrec19': 'SHREC_r',
    'shrec19_ori': 'SHREC_r_ori',
    'dt4d': 'DT4D_r',
    'dt4dintra': 'DT4D_r',
    'dt4dintra_ori': 'DT4D_r_ori',
    'dt4dinter': 'DT4D_r',
    'dt4dinter_ori': 'DT4D_r_ori',
    'tosca': 'TOSCA_r',
    # NOTE(review): every other '*_ori' entry maps to a '*_r_ori' directory;
    # confirm that 'TOSCA_r' (and not 'TOSCA_r_ori') is intended here.
    'tosca_ori': 'TOSCA_r',
}

# Keys starting with one of these prefixes are converted with
# diffusion_net.utils.sparse_np_to_torch instead of torch.from_numpy.
_SPARSE_KEY_PREFIXES = ('gradX', 'gradY', 'L')


def get_data_dirs(root, name, mode):
    """Return the shape and correspondence directories of a dataset.

    Args:
        root: Directory that contains the per-dataset directories.
        name: Dataset name; must be a key of ``DATA_DIRS``.
        mode: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(shape_dir, dataset_dir_name, corr_dir)`` where
        ``shape_dir`` is ``<root>/<dataset_dir_name>/shapes`` and
        ``corr_dir`` is ``<root>/<dataset_dir_name>/correspondences``.

    Raises:
        KeyError: If ``name`` is not a key of ``DATA_DIRS``.
    """
    dataset_dir = DATA_DIRS[name]
    prefix = osp.join(root, dataset_dir)
    shape_dir = osp.join(prefix, 'shapes')
    corr_dir = osp.join(prefix, 'correspondences')
    return shape_dir, dataset_dir, corr_dir


# Disabled collate function, kept for reference.
# def collate_default(data_list):
#     data_dict = defaultdict(list)
#     for pair_dict in data_list:
#         for k, v in pair_dict.items():
#             data_dict[k].append(v)
#     for k in data_dict.keys():
#         if k.startswith('fmap') or k.startswith('evals') or k.endswith('_sub'):
#             data_dict[k] = np.stack(data_dict[k], axis=0)
#     batch_size = len(data_list)
#     for k, v in data_dict.items():
#         assert len(v) == batch_size
#     return data_dict


def prepare_batch(data_dict, device):
    """Convert the NumPy entries of a batch dict to tensors on ``device``.

    The dict is modified in place and also returned. Per entry:

    * A ``np.ndarray`` value becomes a tensor on ``device``.
    * For keys starting with ``'gradX'``, ``'gradY'`` or ``'L'``, every
      element of the value is passed through
      ``diffusion_net.utils.sparse_np_to_torch`` and moved to ``device``
      (presumably these are sparse operators -- confirm against the
      dataset code).
    * Any other value whose first element is a ``np.ndarray`` has each
      element converted with ``torch.from_numpy`` and moved to ``device``.
    * Everything else (strings, scalars, empty lists, lists of
      non-arrays, ...) is left untouched.

    A converted list with exactly one element is stacked into a single
    tensor with a leading dimension of size 1; longer lists stay lists
    of tensors.

    Args:
        data_dict: Batch dictionary with string keys.
        device: Target device accepted by ``Tensor.to``.

    Returns:
        The same ``data_dict`` object, with converted values.
    """
    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for key, value in data_dict.items():
        if isinstance(value, np.ndarray):
            data_dict[key] = torch.from_numpy(value).to(device)
            continue

        if key.startswith(_SPARSE_KEY_PREFIXES):
            # Imported lazily so that batches without these keys do not
            # require diffusion_net to be importable.
            from diffusion_net.utils import sparse_np_to_torch
            tensors = [sparse_np_to_torch(item).to(device) for item in value]
        else:
            # Empty or non-indexable values (e.g. [] or a plain int) used to
            # crash on value[0]; they are passed through unchanged instead.
            try:
                first = value[0]
            except (IndexError, KeyError, TypeError):
                continue
            if not isinstance(first, np.ndarray):
                continue
            tensors = [torch.from_numpy(item).to(device) for item in value]

        if len(tensors) == 1:
            data_dict[key] = torch.stack(tensors, dim=0)
        else:
            data_dict[key] = tensors
    return data_dict