import sys
import os.path as osp
import numpy as np
import torch
from collections import defaultdict

# Add the parent directory of this file to the import path.
ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

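# Dataset key -> directory name under the data root. Several keys share a
# directory (e.g. 'dt4d', 'dt4dintra' and 'dt4dinter' all map to 'DT4D_r').
DATA_DIRS = {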
    'faust': 'FAUST_r',
    'faust_ori': 'FAUST_r_ori',
    'scape': 'SCAPE_r',
    'scape_ori': 'SCAPE_r_ori',
    'smalr': 'SMAL_r',
    'smalr_ori': 'SMAL_r_ori',
    'shrec19': 'SHREC_r',
    'shrec19_ori': 'SHREC_r_ori',
    'dt4d': 'DT4D_r',
    'dt4dintra': 'DT4D_r',
    'dt4dintra_ori': 'DT4D_r_ori',
    'dt4dinter': 'DT4D_r',
    'dt4dinter_ori': 'DT4D_r_ori',
    'tosca': 'TOSCA_r',
    'tosca_ori': 'TOSCA_r',
}


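def get_data_dirs(root, name, mode):
    """Resolve the directories of dataset `name` under `root`.

    Returns a tuple (shape_dir, dir_name, corr_dir), where dir_name is the
    bare directory name from DATA_DIRS (e.g. 'FAUST_r'), not a full path.
    `mode` is currently unused.

    Example (POSIX paths):
        >>> get_data_dirs('/data', 'faust', 'train')
        ('/data/FAUST_r/shapes', 'FAUST_r', '/data/FAUST_r/correspondences')
    """
    prefix = osp.join(root, DATA_DIRS[name])
    shape_dir = osp.join(prefix, 'shapes')
    corr_dir = osp.join(prefix, 'correspondences')
    return shape_dir, DATA_DIRS[name], corr_dir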


# def collate_default(data_list):
#     data_dict = defaultdict(list)
#     for pair_dict in data_list:
#         for k, v in pair_dict.items():
#             data_dict[k].append(v)
#     for k in data_dict.keys():
#         if k.startswith('fmap') or k.startswith('evals') or k.endswith('_sub'):
#             data_dict[k] = np.stack(data_dict[k], axis=0)
#     batch_size = len(data_list)
#     for k, v in data_dict.items():
#         assert len(v) == batch_size

#     return data_dict
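

# Hypothetical sketch (not part of the original code): a minimal collate
# function in the spirit of the commented-out collate_default above, but
# without its np.stack step for 'fmap*', 'evals*' and '*_sub' keys. It gathers
# a list of per-pair dicts into a dict of lists, the layout prepare_batch below
# consumes; values stay in lists because shapes in one batch may have different
# vertex counts and then cannot be stacked. The name collate_as_lists is
# illustrative only.
def collate_as_lists(data_list):
    """Turn [{'k': v0}, {'k': v1}, ...] into {'k': [v0, v1, ...]}."""
    data_dict = {}
    for pair_dict in data_list:
        for k, v in pair_dict.items():
            data_dict.setdefault(k, []).append(v)
    # Every key must appear once per sample, otherwise list positions would
    # no longer correspond to samples.
    batch_size = len(data_list)
    for k, v in data_dict.items():
        if len(v) != batch_size:
            raise ValueError(f'key {k!r} is missing from some samples')
    return data_dict

# Possible usage: prepare_batch(collate_as_lists(samples), torch.device('cpu'))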


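def prepare_batch(data_dict, device):
    """Move the array-like entries of a collated batch onto `device` in place.

    - np.ndarray values become dense torch tensors.
    - Values under keys starting with 'gradX', 'gradY' or 'L' are treated as
      lists of scipy sparse matrices and converted with diffusion_net's
      sparse_np_to_torch.
    - Other lists/tuples of np.ndarray become lists of tensors.

    A list holding a single item is stacked into a tensor with a leading batch
    dimension of 1; longer lists stay lists, since per-shape arrays may differ
    in size. All other values (e.g. names, scalars) are left unchanged.
    """
    for k, v in data_dict.items():
        if isinstance(v, np.ndarray):
            data_dict[k] = torch.from_numpy(v).to(device)
        elif k.startswith(('gradX', 'gradY', 'L')):
            # Sparse operators; note that *any* key starting with 'L' lands here.
            from diffusion_net.utils import sparse_np_to_torch
            tmp_list = [sparse_np_to_torch(st).to(device) for st in v]
            data_dict[k] = torch.stack(tmp_list, dim=0) if len(v) == 1 else tmp_list
        elif isinstance(v, (list, tuple)) and len(v) > 0 and isinstance(v[0], np.ndarray):
            # Elements are already on `device`, so the stacked tensor is too.
            tmp_list = [torch.from_numpy(st).to(device) for st in v]
            data_dict[k] = torch.stack(tmp_list, dim=0) if len(v) == 1 else tmp_list

    return data_dict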