kleinhe
init
c3d0293
import random
import torch
import torch.nn.functional as F
import nvdiffrast.torch as dr
from . import utils
from lib.common.obj import compute_normal
class Renderer(torch.nn.Module):
    """Mesh renderer built on nvdiffrast (rasterize, interpolate, texture, antialias).

    The module owns a single rasterization context (``self.glctx``) that is
    reused for every call.
    """

    def __init__(self):
        super().__init__()
        # Prefer the CUDA rasterizer and fall back to the OpenGL one if the
        # CUDA context cannot be created. Only ``Exception`` is caught so that
        # KeyboardInterrupt / SystemExit are never swallowed here.
        try:
            self.glctx = dr.RasterizeCudaContext()
        except Exception:
            self.glctx = dr.RasterizeGLContext()

    def forward(self, mesh, mvp,
                h=512,
                w=512,
                light_d=None,
                ambient_ratio=1.,
                shading='albedo',
                spp=1,
                mlp_texture=None,
                is_train=False):
        """Render ``mesh`` under the given model-view-projection matrices.

        Args:
            mesh: Mesh object. This method reads ``mesh.v``, ``mesh.f`` and
                ``mesh.vn``; the texture helpers additionally read
                ``mesh.vt``, ``mesh.ft`` and ``mesh.albedo``.
            mvp: [batch, 4, 4] model-view-projection matrices.
            h: int, output height.
            w: int, output width.
            light_d: light direction, reshaped to a column vector and dotted
                with the interpolated normals. Required for lambertian
                shading, unused otherwise.
            ambient_ratio: float, weight of the constant ambient term in
                lambertian shading (1.0 means no directional contribution).
            shading: str. ``'normal'`` outputs the normal map as colour,
                ``'albedo'`` outputs the raw texture, and any other value
                selects lambertian shading.
            spp: super-sampling factor. When > 1 the image is rasterized at
                ``(h * spp, w * spp)`` and scaled back to ``(h, w)``.
            mlp_texture: optional object exposing ``sample(...)``; when given
                the albedo comes from it instead of ``mesh.albedo``.
            is_train: when True, vertex normals are recomputed with
                ``compute_normal`` instead of using ``mesh.vn``.

        Returns:
            color: [batch, h, w, 3], clamped to [0, 1].
            normal: [batch, h, w, 3], normals remapped by ``(n + 1) / 2`` and
                clamped to [0, 1].
            alpha: [batch, h, w, 1], coverage clamped to [0, 1].

        Raises:
            ValueError: if lambertian shading is requested without ``light_d``.
        """
        needs_albedo = shading != 'normal'
        is_lambertian = shading not in ('normal', 'albedo')
        # Fail early with a clear message instead of an AttributeError on
        # ``None.view`` after the (expensive) rasterization has already run.
        if is_lambertian and light_d is None:
            raise ValueError("light_d must be provided for lambertian shading")

        B = mvp.shape[0]
        # Homogeneous vertices (w = 1) transformed into clip space.
        v_homo = F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0)
        v_clip = torch.bmm(v_homo.unsqueeze(0).expand(B, -1, -1),
                           torch.transpose(mvp, 1, 2)).float()  # [B, N, 4]

        res = (int(h * spp), int(w * spp)) if spp > 1 else (h, w)
        rast, rast_db = dr.rasterize(self.glctx, v_clip, mesh.f, res)

        ################################################################################
        # Interpolate attributes
        ################################################################################
        # Coverage: a constant 1 per vertex interpolated over covered pixels.
        alpha, _ = dr.interpolate(torch.ones_like(v_clip[..., :1]), rast, mesh.f)  # [B, H, W, 1]

        if is_train:
            # NOTE(review): normals are recomputed from the clip-space
            # positions of batch element 0 only, then shared by the whole
            # batch. Confirm this is intended.
            vn, _ = compute_normal(v_clip[0, :, :3], mesh.f)
            normal, _ = dr.interpolate(vn[None, ...].float(), rast, mesh.f)
        else:
            normal, _ = dr.interpolate(mesh.vn[None, ...].float(), rast, mesh.f)

        # The texture lookup is skipped entirely for pure normal rendering.
        if needs_albedo:
            if mlp_texture is not None:
                albedo = self.get_mlp_texture(mesh, mlp_texture, rast, rast_db)
            else:
                albedo = self.get_2d_texture(mesh, rast, rast_db)

        if shading == 'normal':
            color = (normal + 1) / 2.
        elif shading == 'albedo':
            color = albedo
        else:  # lambertian
            diffuse = (normal @ light_d.view(-1, 1)).float().clamp(min=0)  # [B, H, W, 1]
            lambertian = ambient_ratio + (1 - ambient_ratio) * diffuse
            color = albedo * lambertian.expand(-1, -1, -1, 3)

        # Remap normals from [-1, 1] to [0, 1] so they can be returned as an image.
        normal = (normal + 1) / 2.

        normal = dr.antialias(normal, rast, v_clip, mesh.f).clamp(0, 1)  # [B, H, W, 3]
        color = dr.antialias(color, rast, v_clip, mesh.f).clamp(0, 1)  # [B, H, W, 3]
        alpha = dr.antialias(alpha, rast, v_clip, mesh.f).clamp(0, 1)  # [B, H, W, 1]

        # inverse super-sampling
        if spp > 1:
            color = utils.scale_img_nhwc(color, (h, w))
            alpha = utils.scale_img_nhwc(alpha, (h, w))
            normal = utils.scale_img_nhwc(normal, (h, w))

        return color, normal, alpha

    def get_mlp_texture(self, mesh, mlp_texture, rast, rast_db, res=2048):
        """Bake a texture from ``mlp_texture`` and sample it for the rendered pixels.

        Args:
            mesh: Mesh object; reads ``mesh.vt``, ``mesh.f`` and ``mesh.ft``.
            mlp_texture: object exposing ``sample(x)`` where ``x`` has 3 columns.
            rast: rasterizer output of the screen-space pass.
            rast_db: derivative output of the screen-space pass.
            res: side length of the square baked texture.

        Returns:
            The sampled albedo. Unlike ``get_2d_texture``, background pixels
            are NOT zeroed here.
        """
        uv = mesh.vt[None, ...]

        # pad to four component coordinate (z = 0, w = 1)
        uv4 = torch.cat((uv, torch.zeros_like(uv[..., 0:1]), torch.ones_like(uv[..., 0:1])), dim=-1)

        # Rasterize the mesh in UV space.
        # NOTE(review): this pass indexes the UV vertices with ``mesh.f`` while
        # the lookup below uses ``mesh.ft``. Confirm the two index buffers are
        # interchangeable for this mesh.
        _rast, _ = dr.rasterize(self.glctx, uv4, mesh.f.int(), (res, res))

        # Sample the MLP on the rasterizer output with its last channel
        # dropped, then reshape to a [..., res, res, 3] texture.
        tex = mlp_texture.sample(_rast[..., :-1].view(-1, 3)).view(*_rast.shape[:-1], 3)

        texc, texc_db = dr.interpolate(mesh.vt[None, ...], rast, mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(
            tex, texc, uv_da=texc_db, filter_mode='linear-mipmap-linear')  # [B, H, W, 3]
        return albedo

    @staticmethod
    def get_2d_texture(mesh, rast, rast_db):
        """Sample ``mesh.albedo`` for the rendered pixels and zero the background.

        Args:
            mesh: Mesh object; reads ``mesh.vt``, ``mesh.ft`` and ``mesh.albedo``.
            rast: rasterizer output of the screen-space pass.
            rast_db: derivative output of the screen-space pass.

        Returns:
            The sampled albedo, set to 0 wherever the last channel of
            ``rast`` is not positive (no triangle covers the pixel).
        """
        texc, texc_db = dr.interpolate(mesh.vt[None, ...], rast, mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(
            mesh.albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear-mipmap-linear')  # [B, H, W, 3]
        # remove background; zeros_like keeps dtype and device aligned with
        # ``albedo`` and avoids a host-to-device scalar copy on every call.
        albedo = torch.where(rast[..., 3:] > 0, albedo, torch.zeros_like(albedo))
        return albedo