import os.path as osp import random import numpy as np import torch import yaml import omegaconf from omegaconf import OmegaConf from scipy.spatial import cKDTree from pathlib import Path class KNNSearch(object): DTYPE = np.float32 NJOBS = 4 def __init__(self, data): self.data = np.asarray(data, dtype=self.DTYPE) self.kdtree = cKDTree(self.data) def query(self, kpts, k, return_dists=False): kpts = np.asarray(kpts, dtype=self.DTYPE) nndists, nnindices = self.kdtree.query(kpts, k=k, workers=self.NJOBS) if return_dists: return nnindices, nndists else: return nnindices def query_ball(self, kpt, radius): kpt = np.asarray(kpt, dtype=self.DTYPE) assert kpt.ndim == 1 nnindices = self.kdtree.query_ball_point(kpt, radius, n_jobs=self.NJOBS) return nnindices def validate_str(x): return x is not None and x != '' def hashing(arr, M): assert isinstance(arr, np.ndarray) and arr.ndim == 2 N, D = arr.shape hash_vec = np.zeros(N, dtype=np.int64) for d in range(D): hash_vec += arr[:, d] * M**d return hash_vec def omegaconf_to_dotdict(hparams): def _to_dot_dict(cfg): res = {} for k, v in cfg.items(): if v is None: res[k] = v elif isinstance(v, omegaconf.DictConfig): res.update({k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()}) elif isinstance(v, (str, int, float, bool)): res[k] = v elif isinstance(v, omegaconf.ListConfig): res[k] = omegaconf.OmegaConf.to_container(v, resolve=True) else: raise RuntimeError('The type of {} is not supported.'.format(type(v))) return res return _to_dot_dict(hparams) def incrange(start, end, step): assert step > 0 res = [start] if start <= end: while res[-1] + step <= end: res.append(res[-1] + step) else: while res[-1] - step >= end: res.append(res[-1] - step) return res def seeding(seed=0): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = True def run_trainer(trainer_cls): cfg_cli = OmegaConf.from_cli() assert cfg_cli.run_mode is not None if cfg_cli.run_mode == 'train': assert cfg_cli.run_cfg is not None cfg = OmegaConf.merge( OmegaConf.load(cfg_cli.run_cfg), cfg_cli, ) OmegaConf.resolve(cfg) cfg = omegaconf_to_dotdict(cfg) seeding(cfg['seed']) trainer = trainer_cls(cfg) trainer.train() trainer.test() elif cfg_cli.run_mode == 'test': assert cfg_cli.run_ckpt is not None log_dir = str(Path(cfg_cli.run_ckpt).parent) cfg = OmegaConf.merge( OmegaConf.load(osp.join(log_dir, 'config.yml')), cfg_cli, ) OmegaConf.resolve(cfg) cfg = omegaconf_to_dotdict(cfg) cfg['test_ckpt'] = cfg_cli.run_ckpt seeding(cfg['seed']) trainer = trainer_cls(cfg) trainer.test() else: raise RuntimeError(f'Mode {cfg_cli.run_mode} is not supported.')