| """ | |
| Simple Gradio app for two-mesh initialization and run phases. | |
| - Upload two meshes (.ply, .obj, .off) | |
| - (Optional) upload a YAML config to override defaults | |
| - Adjust a few numeric settings (sane ranges). Defaults pulled from the provided YAML when present. | |
| - Click **Init** to generate "initialization maps" (here: position/normal-based vertex colors) for both meshes. | |
| - Click **Run** to simulate an iterative evolution with a progress bar, then output another pair of colored meshes. | |
| Replace the bodies of `make_initialization_maps` and `run_evolution` with your real pipeline as needed. | |
| Tested with: gradio >= 4.0, trimesh, pyyaml, numpy. | |
| """ | |
from __future__ import annotations
import os
import io
import time
import json
import shutil
import tempfile
from typing import Dict, Tuple, Optional
from omegaconf import OmegaConf
import gradio as gr
import spaces
import numpy as np
import trimesh
import zero_shot
import yaml
from utils.surfaces import Surface
import notebook_helpers as helper
from utils.meshplot import visu_pts
from utils.fmap import FM_to_p2p
from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
import torch
import argparse
from utils.utils_func import convert_dict

# -----------------------------
# Utils
# -----------------------------
SUPPORTED_EXTS = {".ply", ".obj", ".off", ".stl", ".glb", ".gltf"}

def _safe_ext(path: str) -> str:
    for ext in SUPPORTED_EXTS:
        if path.lower().endswith(ext):
            return ext
    return os.path.splitext(path)[1].lower()

def convert_and_show(mesh_file):
    os.makedirs("tmp/glbs", exist_ok=True)
    if mesh_file is None:
        return None
    mesh = trimesh.load(mesh_file.name)
    tn = int(np.random.rand() * 1e10)
    f_name = f"tmp/glbs/mesh_{tn}.glb"
    mesh.export(f_name)
    return f_name

def convert_and_show_twice(mesh_file_1, mesh_file_2):
    return convert_and_show(mesh_file_1), convert_and_show(mesh_file_2)

def normalize_vertices(vertices: np.ndarray) -> np.ndarray:
    v = vertices.astype(np.float64)
    v = v - v.mean(axis=0, keepdims=True)
    scale = np.linalg.norm(v, axis=1).max()
    if scale == 0:
        scale = 1.0
    v = v / scale
    return v.astype(np.float32)

def ensure_vertex_colors(mesh: trimesh.Trimesh, colors: np.ndarray) -> trimesh.Trimesh:
    out = mesh.copy()
    if colors.shape[1] == 3:
        rgba = np.concatenate([colors, 255 * np.ones((colors.shape[0], 1), dtype=np.uint8)], axis=1)
    else:
        rgba = colors
    out.visual.vertex_colors = rgba
    return out

def export_for_view(surf: Surface, colors: np.ndarray, basename: str, outdir: str) -> str:
    """Export a vertex-colored GLB for the Model3D preview and return its path."""
    glb_path = os.path.join(outdir, f"{basename}.glb")
    mesh = trimesh.Trimesh(surf.vertices, surf.faces, process=False)
    colored_mesh = ensure_vertex_colors(mesh, colors)
    colored_mesh.export(glb_path)
    return glb_path

# -----------------------------
# Settings, config helpers, and data loading
# -----------------------------
DEFAULT_SETTINGS = {
    "deepfeat_conf.fmap.lambda_": 1,
    "sds_conf.zoomout": 32,
    "diffusion.time": 1.0,
    "opt.n_loop": 200,
    "loss.sds": 1.0,
    "loss.proper": 1.0,
}
FLOAT_SLIDERS = {
    # name: (min, max, step)
    "deepfeat_conf.fmap.lambda_": (1e-3, 10.0, 1e-3),
    "diffusion.time": (0.1, 10.0, 0.1),
    "loss.sds": (1e-3, 10.0, 1e-3),
    "loss.proper": (1e-3, 10.0, 1e-3),
}
INT_SLIDERS = {
    "opt.n_loop": (1, 5000, 1),
    "sds_conf.zoomout": (31, 50, 1),
}

def flatten_yaml_floats(d: Dict, prefix: str = "") -> Dict[str, float]:
    flat = {}
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else str(k)
        if isinstance(v, dict):
            flat.update(flatten_yaml_floats(v, key))
        elif isinstance(v, (int, float)):
            flat[key] = float(v)
    return flat
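
# For illustration, flatten_yaml_floats({"opt": {"n_loop": 200}, "loss": {"sds": 1.0}})
# returns {"opt.n_loop": 200.0, "loss.sds": 1.0}, i.e. the dotted keys used by
# DEFAULT_SETTINGS above, keeping only numeric leaves.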

def read_yaml_defaults(yaml_path: Optional[str]) -> Dict[str, float]:
    if yaml_path and os.path.exists(yaml_path):
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)
        flat = flatten_yaml_floats(data)
        # Only keep known keys we expose as controls
        defaults = DEFAULT_SETTINGS.copy()
        for k in list(DEFAULT_SETTINGS.keys()):
            if k in flat:
                defaults[k] = float(flat[k])
        return defaults
    return DEFAULT_SETTINGS.copy()

class Datadicts:
    """Cache the spectral data, surfaces, and source color map for a mesh pair.

    Both the full-resolution and the downsampled ("_down") versions are kept.
    """
    def __init__(self, shape_path, target_path):
        self.shape_path = shape_path
        basename_1 = os.path.basename(shape_path)
        self.shape_dict, self.shape_dict_down = helper.load_data(shape_path, "tmp/" + os.path.splitext(basename_1)[0] + ".npz", "source", make_cache=True)
        self.shape_surf = Surface(filename=shape_path)
        self.shape_surf_down = Surface(filename=self.shape_dict_down["file"])
        self.target_path = target_path
        basename_2 = os.path.basename(target_path)
        self.target_dict, self.target_dict_down = helper.load_data(target_path, "tmp/" + os.path.splitext(basename_2)[0] + ".npz", "target", make_cache=True)
        self.target_surf = Surface(filename=target_path)
        self.target_surf_down = Surface(filename=self.target_dict_down["file"])
        self.cmap1 = visu_pts(self.shape_surf)
        self.cmap1_down = visu_pts(self.shape_surf_down)
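
# Note: the "_down" dictionaries/surfaces above are the downsampled versions that the
# matching and optimization run on, while the "Cup"/"Cdown" matrices loaded in the
# callbacks below are used to lift the resulting functional maps back to the
# full-resolution spectral bases before extracting point-to-point correspondences.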

# -----------------------------
# Gradio UI
# -----------------------------
TMP_ROOT = tempfile.mkdtemp(prefix="meshapp_")

def save_array_txt(arr):
    # Create a temporary file with a .txt suffix
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w") as f:
        np.savetxt(f, arr.astype(int), fmt="%d")  # save as text
    return f.name
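
# The exported .txt correspondence can be read back and used just like build_outputs
# does below: p2p[i] is the index of the source vertex matched to target vertex i, so
# per-vertex source data can be transferred with source_data[p2p]. Illustrative only;
# the file path is a placeholder:
#   p2p = np.loadtxt("correspondences.txt", dtype=int)
#   target_colors = source_colors[p2p]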

def build_outputs(surf_a: Surface, surf_b: Surface, cmap_a: np.ndarray, p2p: np.ndarray, tag: str) -> Tuple[str, str, str]:
    outdir = os.path.join(TMP_ROOT, tag)
    os.makedirs(outdir, exist_ok=True)
    glb_a = export_for_view(surf_a, cmap_a, f"A_{tag}", outdir)
    # Transfer the source color map onto the target through the point-to-point map
    cmap_b = cmap_a[p2p]
    glb_b = export_for_view(surf_b, cmap_b, f"B_{tag}", outdir)
    out_file = save_array_txt(p2p)
    return glb_a, glb_b, out_file

def init_clicked(mesh1_path, mesh2_path,
                 lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val):
    matcher._init()
    print("inside init")
    # Override the config with the current slider values
    cfg.deepfeat_conf.fmap.lambda_ = lambda_val
    cfg.sds_conf.zoomout = zoomout_val
    cfg.deepfeat_conf.fmap.diffusion.time = time_val
    cfg.opt.n_loop = nloop_val
    cfg.loss.sds = sds_val
    cfg.loss.proper = proper_val
    matcher.reconf(cfg)
    if not mesh1_path or not mesh2_path:
        raise gr.Error("Please upload both meshes.")
    global datadicts
    datadicts = Datadicts(mesh1_path, mesh2_path)
    # The initial functional map is predicted on the downsampled meshes
    shape_dict, target_dict = convert_dict(datadicts.shape_dict_down, 'cuda'), convert_dict(datadicts.target_dict_down, 'cuda')
    fmap_model_cuda = matcher.fmap_model.cuda()
    diff_model_cuda = matcher.diffusion_model
    diff_model_cuda.net.cuda()
    C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = fmap_model_cuda({"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda, scale=matcher.fmap_cfg.diffusion.time)
    C12_pred, C12_obj, mask_12 = C12_pred_init
    evecs1, evecs2 = torch.from_numpy(datadicts.shape_dict["evecs"]).cuda(), torch.from_numpy(datadicts.target_dict["evecs"]).cuda()
    C_up, C_down = torch.from_numpy(datadicts.target_dict["Cup"]).cuda(), torch.from_numpy(datadicts.shape_dict_down["Cdown"]).cuda()
    n_fmap = C12_obj.shape[-1]
    with torch.no_grad():
        # Lift the predicted map from the downsampled spectral bases to the full-resolution ones
        C12_all = C_up.squeeze()[:n_fmap, :n_fmap] @ C12_obj.clone().squeeze() @ C_down.squeeze()[:n_fmap, :n_fmap]
    # Convert the functional map into a point-to-point correspondence
    p2p_init = FM_to_p2p(C12_all.cpu().numpy(), datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
    return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")

def run_clicked(mesh1_path, mesh2_path, yaml_path, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val, progress=gr.Progress(track_tqdm=True)):
    if not mesh1_path or not mesh2_path:
        raise gr.Error("Please upload both meshes.")
    # Override the config with the current slider values
    cfg.deepfeat_conf.fmap.lambda_ = lambda_val
    cfg.sds_conf.zoomout = zoomout_val
    cfg.deepfeat_conf.fmap.diffusion.time = time_val
    cfg.opt.n_loop = nloop_val
    cfg.loss.sds = sds_val
    cfg.loss.proper = proper_val
    matcher.reconf(cfg)
    matcher._init()
    global datadicts
    # Reuse the cached data when the same pair of meshes is matched again
    if datadicts is None or not (datadicts.shape_path == mesh1_path and datadicts.target_path == mesh2_path):
        datadicts = Datadicts(mesh1_path, mesh2_path)
    # The optimization runs on the downsampled meshes
    shape_dict, target_dict = convert_dict(datadicts.shape_dict_down, 'cuda'), convert_dict(datadicts.target_dict_down, 'cuda')
    target_normals = torch.from_numpy(datadicts.target_surf_down.surfel / np.linalg.norm(datadicts.target_surf_down.surfel, axis=-1, keepdims=True)).float().to("cuda")
    C12_new, p2p, p2p_init, _, loss_save = matcher.optimize(shape_dict, target_dict, target_normals)
    C_up, C_down = torch.from_numpy(datadicts.target_dict["Cup"]).cuda(), torch.from_numpy(datadicts.shape_dict_down["Cdown"]).cuda()
    evecs1, evecs2 = torch.from_numpy(datadicts.shape_dict["evecs"]).cuda(), torch.from_numpy(datadicts.target_dict["evecs"]).cuda()
    evecs_2trans = evecs2.t() @ torch.diag(torch.from_numpy(datadicts.target_dict["mass"]).cuda())
    with torch.no_grad():
        n_fmap = C12_new.shape[-1]
        # Lift the optimized map to full resolution, refine it with zoomout, and extract the point-to-point map
        C12_all = C_up.squeeze()[:n_fmap, :n_fmap] @ C12_new.clone().squeeze() @ C_down.squeeze()[:n_fmap, :n_fmap]
        C12_end_zo = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_all, matcher.cfg.sds_conf.zoomout)
        p2p_zo, _ = extract_p2p_torch_fmap(C12_end_zo, evecs1, evecs2)
    return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_zo, tag="run")

with gr.Blocks(title="DiffuMatch demo") as demo:
    gr.Markdown(
        """
<div align="center">
<h1>DiffuMatch: Category-Agnostic Spectral Diffusion Priors for Robust Non-rigid Shape Matching</h1>
</div>
<br/>
Upload two meshes and try our ICCV zero-shot method <a href="https://daidedou.github.io/publication/nonrigiddiff">DiffuMatch</a>. <br/>
<b>Init</b> gives you a rough correspondence, and you can then click <b>Run</b> to see if our method is able to match the two shapes! <br/>
<b>Recommended</b>: the method requires the meshes to be aligned (rotation-wise) to work well.<br/>
The method has been adapted to the ZeroGPU environment, so results won't be as good as in the paper. Also, without PyKeOps, the optimization is much slower. <br/>
We recommend using the <a href="https://github.com/daidedou/diffumatch">official code</a> if you want to get the best results. <br/>
The method might not work with topological inconsistencies, and will crash for meshes with a high number of vertices (>10000) because of the preprocessing. Try it out and let us know! <br/>
"""
    )
    with gr.Row():
        with gr.Column():
            mesh1 = gr.File(label="Source Mesh (.ply, .off, .obj)", file_types=[".ply", ".off", ".obj"])
            mesh1_viewer = gr.Model3D(label="Preview Source")
            mesh1.upload(fn=convert_and_show, inputs=mesh1, outputs=mesh1_viewer)
        with gr.Column():
            mesh2 = gr.File(label="Target Mesh (.ply, .off, .obj)", file_types=[".ply", ".off", ".obj"])
            mesh2_viewer = gr.Model3D(label="Preview Target")
            mesh2.upload(fn=convert_and_show, inputs=mesh2, outputs=mesh2_viewer)
    gr.Examples(
        examples=[
            ["examples/man.ply", "examples/woman.ply"],
            ["examples/wolf.ply", "examples/horse.ply"],
            ["examples/cactus.off", "examples/cactus_deformed.off"],
        ],
        fn=convert_and_show_twice,
        inputs=[mesh1, mesh2],
        outputs=[mesh1_viewer, mesh2_viewer],
        label="Try some example pairs",
        cache_examples=True,
    )
| with gr.Accordion("Optional YAML full settings (see github to config)", open=False): | |
| yaml_file = gr.File(label="Optional YAML config", file_types=[".yaml", ".yml"], visible=True) | |
| with gr.Accordion("Settings", open=True): | |
| with gr.Row(): | |
| lambda_val = gr.Slider(minimum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][0], maximum=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][1], step=FLOAT_SLIDERS["deepfeat_conf.fmap.lambda_"][2], value=DEFAULT_SETTINGS["deepfeat_conf.fmap.lambda_"], label="deepfeat_conf.fmap.lambda_") | |
| zoomout_val = gr.Slider(minimum=INT_SLIDERS["sds_conf.zoomout"][0], maximum=INT_SLIDERS["sds_conf.zoomout"][1], step=INT_SLIDERS["sds_conf.zoomout"][2], value=DEFAULT_SETTINGS["sds_conf.zoomout"], label="sds_conf.zoomout") | |
| time_val = gr.Slider(minimum=FLOAT_SLIDERS["diffusion.time"][0], maximum=FLOAT_SLIDERS["diffusion.time"][1], step=FLOAT_SLIDERS["diffusion.time"][2], value=DEFAULT_SETTINGS["diffusion.time"], label="diffusion.time") | |
| with gr.Row(): | |
| nloop_val = gr.Slider(minimum=INT_SLIDERS["opt.n_loop"][0], maximum=INT_SLIDERS["opt.n_loop"][1], step=INT_SLIDERS["opt.n_loop"][2], value=DEFAULT_SETTINGS["opt.n_loop"], label="opt.n_loop") | |
| sds_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.sds"][0], maximum=FLOAT_SLIDERS["loss.sds"][1], step=FLOAT_SLIDERS["loss.sds"][2], value=1, label="loss.sds") | |
| proper_val = gr.Slider(minimum=FLOAT_SLIDERS["loss.proper"][0], maximum=FLOAT_SLIDERS["loss.proper"][1], step=FLOAT_SLIDERS["loss.proper"][2], value=1, label="loss.proper") | |
| with gr.Row(): | |
| init_btn = gr.Button("Init", variant="primary") | |
| run_btn = gr.Button("Run", variant="secondary") | |
| gr.Markdown("### Outputs\n For both **Init** and **Run** stages, \n we provide a preview of the correspondences as coloreds glbs, \n and the obtained correspondence as a .txt file.") | |
| with gr.Tab("Init"): | |
| with gr.Row(): | |
| init_view_a = gr.Model3D(label="Shape") | |
| init_view_b = gr.Model3D(label="Target correspondence (init)") | |
| with gr.Row(): | |
| out_file_init = gr.File(label="Download correspondences TXT") | |
| with gr.Tab("Run"): | |
| with gr.Row(): | |
| run_view_a = gr.Model3D(label="Shape") | |
| run_view_b = gr.Model3D(label="Target correspondence (run)") | |
| with gr.Row(): | |
| out_file = gr.File(label="Download correspondences TXT") | |
    init_btn.click(
        fn=init_clicked,
        inputs=[mesh1, mesh2, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
        outputs=[init_view_a, init_view_b, out_file_init],
        api_name="init",
    )
    run_btn.click(
        fn=run_clicked,
        inputs=[mesh1, mesh2, yaml_file, lambda_val, zoomout_val, time_val, nloop_val, sds_val, proper_val],
        outputs=[run_view_a, run_view_b, out_file],
        api_name="run",
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch the gradio demo")
    parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
    parser.add_argument('--share', action="store_true")
    args = parser.parse_args()
    cfg = OmegaConf.load(args.config)
    print("Making matcher")
    matcher = zero_shot.Matcher(cfg)
    print("Matcher ready")
    # shutil.rmtree("tmp")
    os.makedirs("tmp", exist_ok=True)
    os.makedirs("tmp/plys", exist_ok=True)
    datadicts = None
    demo.launch(share=args.share)
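
# To launch the demo locally (illustrative; assumes this file is saved as app.py,
# with the config path following the argparse default above):
#   python app.py --config config/matching/sds.yaml
#   python app.py --share   # additionally creates a public Gradio share link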