Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import scipy | |
| import os | |
| import sys | |
| import random | |
| import numpy as np | |
| import torch | |
| import time | |
| from datetime import datetime | |
| import importlib | |
| import json | |
| import argparse | |
| from omegaconf import OmegaConf | |
| from snk.loss import PrismRegularizationLoss | |
| from snk.prism_decoder import PrismDecoder | |
| from shape_models.fmap import DFMNet | |
| from shape_models.encoder import Encoder | |
| from diffu_models.losses import VELoss, VPLoss, EDMLoss | |
| from diffu_models.sds import guidance_grad | |
| from utils.torch_fmap import torch_zoomout, knnsearch, extract_p2p_torch_fmap | |
| from utils.utils_func import convert_dict, str_delta, ensure_pretrained_file | |
| from utils.eval import accuracy | |
| from utils.mesh import save_ply, load_mesh | |
| from shape_data import get_data_dirs | |
| from utils.pickle_stuff import safe_load_with_fallback | |
| from utils.geometry import compute_operators, load_operators | |
| from utils.surfaces import Surface | |
| import sys | |
| import spaces | |
| try: | |
| import google.colab | |
| print("Running Colab") | |
| from tqdm import tqdm | |
| except ImportError: | |
| print("Running local") | |
| from tqdm.auto import tqdm | |
def seed_everything(seed=42):
    """Seed every random number generator used here and make cuDNN deterministic.

    Args:
        seed: value passed to Python's ``random``, NumPy and torch, and
            exported as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything()
class Tee:
    """File-like object that duplicates every write to several text streams.

    It is installed as ``sys.stdout`` to mirror console output into a log
    file. Every write is flushed right away so all outputs stay in sync.
    """

    def __init__(self, *outputs):
        self.outputs = outputs

    def write(self, message):
        """Write *message* to every output, flushing each one.

        Returns:
            The number of characters written, as ``io.TextIOBase.write`` does.
        """
        for output in self.outputs:
            output.write(message)
            output.flush()  # ensure it's written immediately
        return len(message)

    def flush(self):
        """Flush every output."""
        for output in self.outputs:
            output.flush()

    def isatty(self):
        """Report a terminal only if there is at least one output and all of them are terminals.

        Without this method, code that probes ``sys.stdout.isatty()`` raises
        ``AttributeError`` once ``sys.stdout`` has been replaced by a ``Tee``.
        """
        return bool(self.outputs) and all(
            getattr(output, "isatty", lambda: False)() for output in self.outputs
        )
class DiffModel:
    """Load a trained diffusion network and the loss object it was trained with.

    Args:
        cfg: mapping with ``train_dir``. Either the literal ``"pretrained"``
            (the published snapshot and its ``training_options.json`` are
            fetched via ``ensure_pretrained_file``) or a directory holding run
            folders whose names start with the zero-padded ``diff_num_exp``.
        device: device the loaded ``'ema'`` network is moved to.

    Attributes:
        net: the ``'ema'`` entry of the loaded snapshot.
        loss_sde: instance of the loss named by ``hyper_params.loss_name`` in
            the training options, or ``None`` for an unrecognised name.

    Raises:
        FileNotFoundError: no matching run directory or no ``.pkl`` snapshot.
    """

    def __init__(self, cfg, device="cpu"):
        if cfg["train_dir"] == "pretrained":
            url = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/network-snapshot-041216.pkl"
            network_pkl = ensure_pretrained_file(url, "pretrained")
            url_json = "https://huggingface.co/daidedou/diffumatch_model/resolve/main/training_options.json"
            json_filename = ensure_pretrained_file(url_json, "pretrained", filename="training_options.json")
        else:
            netdir = self._find_experiment_dir(cfg["train_dir"], cfg["diff_num_exp"])
            json_filename = os.path.join(netdir, "training_options.json")
            network_pkl = os.path.join(netdir, self._latest_snapshot(netdir))
        with open(json_filename) as json_file:
            train_cfg = json.load(json_file)
        print(f'Loading network from "{network_pkl}"...')
        # NOTE(review): the snapshot is a pickle (possibly downloaded); only load trusted files.
        self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
        print('Done!')
        loss_name = train_cfg['hyper_params']['loss_name']
        loss_classes = {"EDMLoss": EDMLoss, "VPLoss": VPLoss, "VELoss": VELoss}
        loss_cls = loss_classes.get(loss_name)
        self.loss_sde = loss_cls() if loss_cls is not None else None

    @staticmethod
    def _find_experiment_dir(train_dir, num_exp):
        """Return the run directory in *train_dir* whose name starts with ``f"{num_exp:05d}"``."""
        prefix = f"{num_exp:05d}"
        candidates = sorted(
            name for name in os.listdir(train_dir)
            if name.startswith(prefix) and os.path.isdir(os.path.join(train_dir, name))
        )
        if not candidates:
            raise FileNotFoundError(f"No experiment directory starting with {prefix!r} in {train_dir!r}")
        # Several runs may share a prefix; take the last in sorted order so the choice is deterministic.
        return os.path.join(train_dir, candidates[-1])

    @staticmethod
    def _latest_snapshot(netdir):
        """Return the ``.pkl`` file in *netdir* with the largest trailing ``-<number>`` in its stem."""
        pkls = [name for name in os.listdir(netdir) if name.endswith(".pkl")]
        if not pkls:
            raise FileNotFoundError(f"No .pkl snapshot found in {netdir!r}")
        return max(pkls, key=lambda name: int(name.split(".")[0].split("-")[-1]))
class Matcher(object):
    """Match shape pairs with a deep functional-map network.

    A pair is matched either by one forward pass refined with ZoomOut
    (``zo_shot``) or, when ``cfg.optimize`` is set, by ``n_loop`` Adam steps
    on freshly created networks (``optimize``). The optimisation can add SNK
    reconstruction losses and diffusion guidance on the functional maps.
    """

    def __init__(self, cfg, device=None):
        """Store the configuration and load the diffusion model if ``cfg.sds`` is set.

        Args:
            cfg: OmegaConf-style config (both ``get``/item and attribute access are used).
            device: device that input dictionaries are moved to. Defaults to
                ``cuda:0`` when CUDA is available, otherwise CPU.
        """
        self.diffusion_model = None
        if cfg.get("sds", False):
            # Loaded on CPU; moved to CUDA lazily by ``_diffusion_on_cuda``.
            self.diffusion_model = DiffModel(cfg["sds_conf"], "cpu")
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self._configure(cfg)

    def _configure(self, cfg):
        """Cache the settings derived from *cfg* (map size, loop count, flags)."""
        self.cfg = cfg
        self.n_fmap = cfg["deepfeat_conf"]["fmap"]["n_fmap"]
        self.n_loop = cfg.opt.get("n_loop", 0) if cfg.get("optimize", False) else 0
        self.snk = cfg.get("snk", False)
        self.fmap_cfg = cfg.deepfeat_conf.fmap
        self.dataloaders = dict()

    def reconf(self, cfg):
        """Switch to *cfg*; an already-loaded diffusion model is reused, otherwise loaded if requested."""
        self._configure(cfg)
        if self.diffusion_model is None and cfg.get("sds", False):
            self.diffusion_model = DiffModel(cfg["sds_conf"], "cpu")

    def _init(self):
        """(Re)create the fmap network, the optional SNK encoder/decoder and the Adam optimiser on CUDA."""
        self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).cuda()
        if self.snk:
            self.encoder = Encoder().cuda()
            self.decoder = PrismDecoder(dim_in=515).cuda()
            self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
            self.soft_p2p = True
            params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
                self.decoder.parameters())
        else:
            params_to_opt = self.fmap_model.parameters()
        self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
        self.eye = torch.eye(self.n_fmap).float().cuda()
        self.eye.requires_grad = False

    def _diffusion_on_cuda(self):
        """Move the diffusion network (if any) to CUDA and return the model wrapper, or ``None``."""
        if self.diffusion_model is not None:
            self.diffusion_model.net.cuda()
        return self.diffusion_model

    def fmap(self, shape_dict, target_dict):
        """Run the fmap network on a pair.

        Returns:
            (C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1,
            evecs_trans2, mask_12, mask_21). Without ``use_diff`` the ``*_obj``
            maps equal the predictions and the masks are ``None``.

        Raises:
            ValueError: ``fmap.use_diff`` is set but no diffusion model was loaded.
        """
        diff_model_cuda = self._diffusion_on_cuda()
        if self.fmap_cfg.get("use_diff", False):
            if diff_model_cuda is None:
                raise ValueError("deepfeat_conf.fmap.use_diff requires sds=True so a diffusion model is loaded")
            C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
                {"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda,
                scale=self.fmap_cfg.diffusion.time)
            C12_pred, C12_obj, mask_12 = C12_pred
            C21_pred, C21_obj, mask_21 = C21_pred
        else:
            C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
                {"shape1": shape_dict, "shape2": target_dict})
            C12_obj, C21_obj = C12_pred, C21_pred
            mask_12, mask_21 = None, None
        return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21

    def zo_shot(self, shape_dict, target_dict):
        """One forward pass followed by ZoomOut up to ``cfg.zo_shot``; returns (fmap, p2p)."""
        self._init()
        evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
        _, C12_mask_init, _, _, _, _, _, _, _, _ = self.fmap(shape_dict, target_dict)
        evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
        new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
        indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
        return new_FM, indKNN_new

    def optimize(self, shape_dict, target_dict, target_normals):
        """Refine the functional maps of one pair by ``self.n_loop`` optimisation steps.

        Returns:
            (C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save): the final
            1->2 map (the initial prediction when ``n_loop == 0``), point-to-point
            maps extracted from the final and the initial map, the last SNK
            reconstruction (``None`` without SNK) and per-step loss histories.

        Raises:
            ValueError: ``snk`` is enabled without ``opt.soft_p2p`` (the SNK
                losses need the soft point-to-point maps).
        """
        self._init()
        diff_model_cuda = self._diffusion_on_cuda()
        evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
        C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
        evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
        evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
        n_verts_target = target_dict["vertices"].shape[-2]
        loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [],
                     "proper": []}
        snk_rec = None
        C12_new = C12_pred_init
        if self.snk and self.n_loop > 0 and not self.cfg.opt.soft_p2p:
            raise ValueError("snk=True requires opt.soft_p2p=True")
        use_diff = self.fmap_cfg.get("use_diff", False)
        for _ in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
            C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(
                shape_dict, target_dict)
            if self.cfg.opt.soft_p2p:
                # A la SNK: soft p2p maps, re-projected once through the induced functional map.
                # P2P 2 -> 1
                soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap],
                                        prod=True)
                C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
                soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap],
                                        prod=True)
                # P2P 1 -> 2
                soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap],
                                        prod=True)
                C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
                soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap],
                                        prod=True)
            else:
                C12_new, C21_new = C12_pred, C21_pred
            l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye) ** 2).mean() + (
                (C21_new.squeeze() @ C21_new.squeeze().T - self.eye) ** 2).mean()
            l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye) ** 2).mean() + (
                (C21_new.squeeze() @ C12_new.squeeze() - self.eye) ** 2).mean()
            l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(
                target_dict["evals"][:self.n_fmap]) @ C12_new) ** 2).mean()
            l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
                shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
            # NOTE(review): the cycle loss is only non-zero in the SNK branch below. A cycle term
            # computed for every soft_p2p run used to be overwritten by this reset; that numeric
            # behaviour is kept here — confirm whether cycle should also apply without SNK.
            l_cycle = torch.as_tensor(0.).float().cuda()
            l_prism = torch.as_tensor(0.).float().cuda()
            l_mse = torch.as_tensor(0.).float().cuda()
            if self.snk:
                # Latent vector
                latents = self.encoder(shape_dict)
                latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
                # Prism decoder
                feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
                snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
                l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
                l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
                l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
                    dim=-1).mean()
            l_sds, l_proper = torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda()
            if use_diff:
                lr = self.optim.param_groups[0]['lr']
                if self.fmap_cfg.diffusion.get("abs", False):
                    C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
                else:
                    C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
                grad_12, _ = guidance_grad(C12_in, diff_model_cuda.net, grad_scale=1,
                                           batch_size=self.fmap_cfg.diffusion.batch_sds,
                                           scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
                with torch.no_grad():
                    denoised_12 = C12_pred - lr * grad_12
                    targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(),
                                               self.cfg.sds_conf.zoomout)
                l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap]
                                - targets_12.squeeze()[:self.n_fmap, :self.n_fmap]) ** 2).mean()
                grad_21, _ = guidance_grad(C21_in, diff_model_cuda.net, grad_scale=1,
                                           batch_size=self.fmap_cfg.diffusion.batch_sds,
                                           scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
                with torch.no_grad():
                    denoised_21 = C21_pred - lr * grad_21
                    targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(),
                                               self.cfg.sds_conf.zoomout)
                l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap]
                                - targets_21.squeeze()[:self.n_fmap, :self.n_fmap]) ** 2).mean()
                l_proper = l_proper_12 + l_proper_21
                # NOTE(review): abs() is applied here regardless of diffusion.abs — confirm intent.
                l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap]
                          - denoised_12.squeeze()[:self.n_fmap, :self.n_fmap]) ** 2).mean()
                l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap]
                           - denoised_21.squeeze()[:self.n_fmap, :self.n_fmap]) ** 2).mean()
            # (config key, loss value, allowed) — summed in this fixed order with weights from cfg.loss.
            weighted_terms = (
                ("ortho", l_ortho, True),
                ("bij", l_bij, True),
                ("lap", l_lap, True),
                ("cycle", l_cycle, True),
                ("mse_rec", l_mse, True),
                ("prism_rec", l_prism, True),
                ("sds", l_sds, use_diff),
                ("proper", l_proper, use_diff),
            )
            loss = torch.as_tensor(0.).float().to("cuda")
            for key, value, allowed in weighted_terms:
                weight = self.cfg.loss.get(key, 0)
                if weight > 0 and allowed:
                    loss += weight * value
            loss.backward()
            self.optim.step()
            self.optim.zero_grad()
            loss_save["cycle"].append(l_cycle.item())
            loss_save["ortho"].append(l_ortho.item())
            loss_save["bij"].append(l_bij.item())
            loss_save["lap"].append(l_lap.item())
            loss_save["sds"].append(l_sds.item())
            loss_save["proper"].append(l_proper.item())
            loss_save["mse"].append(l_mse.item())
            loss_save["prism"].append(l_prism.item())
        indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
        indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
        return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save

    def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
        """Match one pair, write results under *output_pair*, and optionally evaluate.

        Args:
            pair_batch: 6-tuple ``(shape, _, target, _, target_normals, mapinfo)``.
            output_pair: output directory (created if needed).
            geod_path: directory holding ``<file>.mat`` geodesic matrices (used when *eval*).
            refine: also run a ZoomOut refinement to size 128 and save ``p2p_zo.npy``.
            eval: compute geodesic errors with ``accuracy``.

        Returns:
            ``(p2p, p2p_refined_zo, loss_save, err, err_zo)`` if *eval* and *refine*;
            ``(p2p, loss_save, err)`` if only *eval*; otherwise ``(p2p, loss_save)``.
        """
        shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
        shape_dict_device = convert_dict(shape_dict, self.device)
        target_dict_device = convert_dict(target_dict, self.device)
        print(shape_dict_device["vertices"].device)
        os.makedirs(output_pair, exist_ok=True)
        if self.cfg.get("optimize", False):
            C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device,
                                                                       target_normals.to(self.device))
            np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
            np.save(os.path.join(output_pair, "losses.npy"), loss_save)
        else:
            C12_new, p2p = self.zo_shot(shape_dict_device, target_dict_device)
            snk_rec, loss_save = None, None
        np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
        np.save(os.path.join(output_pair, "p2p.npy"), p2p)
        if snk_rec is not None:
            save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(),
                     target_dict["faces"])
        if refine:
            evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
            evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
            new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)
            p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
            np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p_refined_zo)
        if eval:
            # ``import scipy`` alone does not guarantee the ``io`` submodule is loaded.
            import scipy.io
            file_i, vts_1, vts_2 = mapinfo
            mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
            A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
            _, dist = accuracy(p2p[vts_2], vts_1, A_geod,
                               sqrt_area=sqrt_area,
                               return_all=True)
            if refine:
                _, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
                                      sqrt_area=sqrt_area,
                                      return_all=True)
                np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
                return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
            return p2p, loss_save, dist.mean()
        return p2p, loss_save

    def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
        """Match and evaluate every pair of *dataset*, mirroring stdout into ``save_dir/log.txt``."""
        os.makedirs(save_dir, exist_ok=True)
        num_pairs = len(dataset)
        refine = self.cfg.get("refine", False)
        geod_path = os.path.join(data_dir, "geomats", name_dataset)
        all_accs = []
        all_accs_zo = []
        t1 = datetime.now()
        save_txt = os.path.join(save_dir, "log.txt")
        previous_stdout = sys.stdout
        with open(save_txt, 'w') as log_file:
            # Tee console output into the log; always restore stdout, even on error.
            sys.stdout = Tee(previous_stdout, log_file)
            try:
                for id_pair, batch in enumerate(dataset):
                    shape_dict, _, target_dict, _, _, _ = batch
                    print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
                    name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
                    if refine:
                        _, _, _, dist, dist_zo = self.match(batch, name_exp, geod_path, eval=True, refine=True)
                    else:
                        _, _, dist = self.match(batch, name_exp, geod_path, eval=True, refine=False)
                    delta = datetime.now() - t1
                    fm_delta = str_delta(delta)
                    remains = ((delta / (id_pair + 1)) * num_pairs) - delta
                    fm_remains = str_delta(remains)
                    all_accs.append(dist)
                    accs_mean = np.mean(all_accs)
                    if refine:
                        all_accs_zo.append(dist_zo)
                        accs_zo = np.mean(all_accs_zo)
                        print(
                            f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
                    else:
                        print(
                            f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
                if refine:
                    print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
                else:
                    print(f"mean error : {np.mean(all_accs)}")
            finally:
                sys.stdout = previous_stdout

    def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
        """Load a mesh and its spectral operators (cached in ``cfg.cache``) as tensors on ``self.device``.

        ``factor`` is accepted but not used. Returns ``(data_dict, area)``.
        """
        name = os.path.basename(os.path.splitext(file)[0])
        cache_file = "single_" + name + ".npz"
        verts_shape, faces, vnormals, area_shape, center_shape = load_mesh(file, return_vnormals=True)
        cache_path = os.path.join(self.cfg.cache, cache_file)
        print("Cache is: ", cache_path)
        if not os.path.exists(cache_path) or make_cache:
            print("Computing operators ...")
            compute_operators(verts_shape, faces, vnormals, num_evecs, cache_path, force_save=make_cache)
        data_dict = load_operators(cache_path)
        data_dict['name'] = name
        data_dict_torch = convert_dict(data_dict, self.device)
        return data_dict_torch, area_shape

    def match_files(self, file_shape, file_target):
        """Match two mesh files; results go to ``cfg.output/<shape>_<target>``.

        Returns ``(shape_dict, target_dict, p2p)``.
        """
        batch_shape, _ = self.load_data(file_shape)
        batch_target, _ = self.load_data(file_target)
        target_surf = Surface(filename=file_target)
        # Row-normalise the surface's ``surfel`` vectors to unit length.
        target_normals = torch.from_numpy(
            target_surf.surfel / np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
        # ``match`` unpacks (shape, _, target, _, target_normals, mapinfo): normals go in slot 4.
        batch = batch_shape, None, batch_target, None, target_normals, None
        output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_target["name"])
        p2p, _ = self.match(batch, output_folder, None)
        return batch_shape, batch_target, p2p
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
    parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
    parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
    parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
    parser.add_argument('--output', type=str, default="results", help="where to store experience results")
    args = parser.parse_args()
    # Command-line values are merged last, so they override the YAML config.
    arg_cfg = OmegaConf.from_dotlist(
        [f"{k}={v}" for k, v in vars(args).items() if v is not None]
    )
    yaml_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(yaml_cfg, arg_cfg)
    dataset_name = args.dataset.lower()
    if cfg.get("oriented", False):
        dataset_name += "_ori"
    dataset_module = importlib.import_module(f'shape_data.{args.dataset.lower()}')
    shape_cls = getattr(dataset_module, 'ShapeDataset')
    pair_cls = getattr(dataset_module, 'ShapePairDataset')
    data_dir, name_data_geo, corr_dir = get_data_dirs(args.datadir, dataset_name, 'test')
    name_data_geo = "_".join(name_data_geo.split("_")[:2])
    dset_shape = shape_cls(data_dir, "cache/fmaps", "test", oriented=cfg.get("oriented", False))
    print("Preprocessing shapes done.")
    dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
    exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
    output_logs = os.path.join(args.output, name_data_geo, exp_time)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    matcher = Matcher(cfg)
    # ``match``/``load_data`` move inputs to ``matcher.device``; set it explicitly.
    matcher.device = device
    matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)