renderer surface norm init #24

Open · wants to merge 2 commits into base: main
40 changes: 40 additions & 0 deletions b3d/enumerative_proposals.py
@@ -22,6 +22,46 @@ def _enumerate_and_select_best_move(trace, addressses, key, all_deltas):
    _enumerate_and_select_best_move, static_argnames=["addressses"]
)

def _enumerate_and_return_scores(trace, addressses, key, all_deltas):
    addr = addressses.const[0]
    current_pose = trace[addr]
    # Score each batch of candidate poses. Note: only the poses and scores
    # from the final element of all_deltas are returned.
    for i in range(len(all_deltas)):
        test_poses = current_pose @ all_deltas[i]
        potential_scores = b3d.enumerate_choices_get_scores(
            trace, jax.random.PRNGKey(0), addressses, test_poses
        )
    return test_poses, potential_scores


enumerate_and_return_scores = jax.jit(
    _enumerate_and_return_scores, static_argnames=["addressses"]
)

def _enumerate_and_sample(trace, addressses, key, all_deltas):
    addr = addressses.const[0]
    # Apply every candidate delta to the current pose estimate.
    test_poses = trace[addr] @ all_deltas
    # Score the candidates in 10 batches to bound memory use.
    test_poses_batches = test_poses.split(10)
    scores = jnp.concatenate(
        [
            b3d.enumerate_choices_get_scores(
                trace, key, genjax.Pytree.const(addr), poses
            )
            for poses in test_poses_batches
        ]
    )
    # Commit the highest-scoring pose (argmax, not a random sample) to the trace.
    trace = b3d.update_choices(
        trace,
        jax.random.PRNGKey(0),
        genjax.Pytree.const(addr),
        test_poses[scores.argmax()],
    )
    key = jax.random.split(key, 2)[-1]
    return trace, key


enumerate_and_sample = jax.jit(
    _enumerate_and_sample, static_argnames=["addressses"]
)

def _gvmf_and_select_best_move(trace, key, variance, concentration, address, number):
    addr = address.const
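Review note: a minimal usage sketch for the new enumerate_and_sample helper, assuming an existing model trace and a batched Pose of candidate deltas (the address name below is hypothetical, and the candidate count is assumed to split evenly into the 10 score batches):

import jax
import genjax

key = jax.random.PRNGKey(42)
# `trace` is an existing b3d model trace; `all_deltas` is a batched Pose of
# candidate perturbations (both assumed to exist already).
addressses = genjax.Pytree.const(("object_pose_0",))  # hypothetical address
# Scores every candidate, commits the argmax pose, and advances the key.
trace, key = enumerate_and_sample(trace, addressses, key, all_deltas)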
73 changes: 72 additions & 1 deletion b3d/model.py
@@ -67,6 +67,7 @@ def get_rgb_depth_inliers_from_trace(trace):
    return get_rgb_depth_inliers_from_observed_rendered_args(observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args)

def get_rgb_depth_inliers_from_observed_rendered_args(observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args):
    # Convert RGB to CIELAB so color error is measured in a perceptually uniform space.
    observed_lab = b3d.rgb_to_lab(observed_rgb)
    rendered_lab = b3d.rgb_to_lab(rendered_rgb)
    error = (
@@ -100,6 +101,7 @@ def logpdf(self, observed, rendered_rgb, rendered_depth, model_args, fx, fy, far):

        corrected_depth = rendered_depth + (rendered_depth == 0.0) * far
        areas = (corrected_depth / fx) * (corrected_depth / fy)
        # areas = jnp.ones(areas.shape)  # debug override: uniform pixel areas (disabled)

        return jnp.log(
            # This is leaving out a 1/A (which does depend upon the scene)
@@ -110,6 +112,39 @@

rgbd_sensor_model = RGBDSensorModel()

class RGBDSensorModelNorm(ExactDensity, genjax.JAXGenerativeFunction):
    def sample(self, key, rendered_rgb, rendered_depth, norm_im, model_args, fx, fy, far):
        return (rendered_rgb, rendered_depth)

    def logpdf(self, observed, rendered_rgb, rendered_depth, norm_im, model_args, fx, fy, far):
        observed_rgb, observed_depth = observed

        inliers, color_inliers, depth_inliers, outliers, undecided, valid_data_mask = get_rgb_depth_inliers_from_observed_rendered_args(
            observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args
        )

        inlier_score = model_args.inlier_score
        outlier_prob = model_args.outlier_prob
        multiplier = model_args.color_multiplier

        # Treat zero-depth (background) pixels as lying on a pseudo back plane at `far`.
        corrected_depth = rendered_depth + (rendered_depth == 0.0) * far

        # Cosine between each pixel's surface normal and the camera viewing axis.
        flat_cos = jnp.abs(norm_im @ jnp.array([0, 0, 1]))
        # Divide out foreshortening for rendered pixels, clipping the cosine away
        # from zero at grazing angles; background pixels keep a factor of 1.
        inv_clip_cos = jnp.multiply(1 / jnp.clip(flat_cos, 0.01, 1), (rendered_depth != 0.0) * 1) + (rendered_depth == 0.0) * 1.0
        depth_corr = jnp.multiply(inv_clip_cos, corrected_depth)
        areas_flat = (depth_corr / fx) * (depth_corr / fy)

        return jnp.log(
            # This is leaving out a 1/A (which does depend upon the scene)
            inlier_score * jnp.sum(inliers * areas_flat) +
            1.0 * jnp.sum(undecided * areas_flat) +
            outlier_prob * jnp.sum(outliers * areas_flat)
        ) * multiplier


rgbd_sensor_model_surfacenorm = RGBDSensorModelNorm()


def model_multiobject_gl_factory(renderer, image_likelihood=rgbd_sensor_model):
    @genjax.static_gen_fn
    def model(
@@ -145,6 +180,42 @@ def model(
        return (observed_rgb, rendered_rgb), (observed_depth, rendered_depth)
    return model

# Default to the surface-normal-aware likelihood: rgbd_sensor_model does not
# accept the norm_im argument passed below.
def model_multiobject_gl_factory_normal(renderer, image_likelihood=rgbd_sensor_model_surfacenorm):
    @genjax.static_gen_fn
    def model(
        _num_obj_arr,  # new
        model_args,
        object_library
    ):
        object_poses = Pose(jnp.zeros((0, 3)), jnp.zeros((0, 4)))
        object_indices = jnp.empty((0,), dtype=int)
        camera_pose = uniform_pose(jnp.ones(3) * -100.0, jnp.ones(3) * 100.0) @ f"camera_pose"

        for i in range(_num_obj_arr.shape[0]):
            # An identity of -1 marks an empty slot; its faces are masked out below.
            object_identity = uniform_discrete(jnp.arange(-1, len(object_library.ranges))) @ f"object_{i}"
            object_indices = jnp.concatenate((object_indices, jnp.array([object_identity])))

            object_pose = uniform_pose(jnp.ones(3) * -100.0, jnp.ones(3) * 100.0) @ f"object_pose_{i}"
            object_poses = Pose.concatenate_poses([object_poses, camera_pose.inv() @ object_pose[None, ...]])

        # Render color, depth, and an approximate surface-normal image.
        rendered_rgb, rendered_depth, rendered_norm_im = renderer.render_attribute_normal(
            object_poses,
            object_library.vertices,
            object_library.faces,
            object_library.ranges[object_indices] * (object_indices >= 0).reshape(-1, 1),
            object_library.attributes
        )
        observed_rgb, observed_depth = image_likelihood(
            rendered_rgb, rendered_depth,
            rendered_norm_im,
            model_args,
            renderer.fx, renderer.fy,
            1.0
        ) @ "observed_rgb_depth"
        return (observed_rgb, rendered_rgb), (observed_depth, rendered_depth)
    return model

def get_rendered_rgb_depth_from_trace(trace):
    (observed_rgb, rendered_rgb), (observed_depth, rendered_depth) = trace.get_retval()
    return (rendered_rgb, rendered_depth)
@@ -200,4 +271,4 @@ def rerun_visualize_trace_t(trace, t, modes=["rgb", "depth", "inliers"]):
            pose.apply(vertices),
            colors=(attributes * 255).astype(jnp.uint8),
        ),
    )
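Review note: the new likelihood weights each pixel by an estimate of the surface area it covers, and divides out the cosine between the rendered normal and the camera axis so obliquely viewed surfaces are not under-weighted. A standalone sketch of that area computation, mirroring the logpdf above (the clipping keeps grazing angles from dividing by a near-zero cosine):

import jax.numpy as jnp

def foreshortened_pixel_areas(rendered_depth, norm_im, fx, fy, far):
    # Background (zero-depth) pixels are treated as a pseudo back plane at `far`.
    corrected_depth = rendered_depth + (rendered_depth == 0.0) * far
    # Cosine between each pixel's surface normal and the camera z-axis.
    flat_cos = jnp.abs(norm_im @ jnp.array([0.0, 0.0, 1.0]))
    # Divide out foreshortening for rendered pixels; background keeps a factor of 1.
    inv_clip_cos = (rendered_depth != 0.0) / jnp.clip(flat_cos, 0.01, 1.0) + (
        rendered_depth == 0.0
    ) * 1.0
    depth_corr = inv_clip_cos * corrected_depth
    # A pixel subtends roughly (d / fx) * (d / fy) of surface area at depth d.
    return (depth_corr / fx) * (depth_corr / fy)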
47 changes: 47 additions & 0 deletions b3d/pose.py
@@ -90,6 +90,53 @@ def camera_from_position_and_target(
    rotation_matrix = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)])
    return Pose(position, Rot.from_matrix(rotation_matrix).as_quat())

def rotation_from_axis_angle(axis, angle):
    """Creates a rotation matrix from an axis and angle.

    Args:
        axis (jnp.ndarray): The axis vector. Shape (3,)
        angle (float): The angle in radians.
    Returns:
        jnp.ndarray: The rotation matrix. Shape (3, 3)
    """
    sina = jnp.sin(angle)
    cosa = jnp.cos(angle)
    direction = axis / jnp.linalg.norm(axis)
    # Rodrigues' formula: rotation matrix around a unit vector.
    R = jnp.diag(jnp.array([cosa, cosa, cosa]))
    R = R + jnp.outer(direction, direction) * (1.0 - cosa)
    direction = direction * sina
    R = R + jnp.array(
        [
            [0.0, -direction[2], direction[1]],
            [direction[2], 0.0, -direction[0]],
            [-direction[1], direction[0], 0.0],
        ]
    )
    return R


def from_rot(rotation):
    """Creates a pose from a rotation matrix.

    Args:
        rotation (jnp.ndarray): The rotation matrix. Shape (3, 3)
    Returns:
        Pose object
    """
    return Pose.from_matrix(jnp.vstack(
        [jnp.hstack([rotation, jnp.zeros((3, 1))]), jnp.array([0.0, 0.0, 0.0, 1.0])]
    ))


def from_axis_angle(axis, angle):
    """Creates a pose from an axis and angle.

    Args:
        axis (jnp.ndarray): The axis vector. Shape (3,)
        angle (float): The angle in radians.
    Returns:
        Pose object
    """
    return from_rot(rotation_from_axis_angle(axis, angle))

@register_pytree_node_class
class Pose:
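Review note: a quick sanity check of the Rodrigues construction, as a sketch (the commented values are what the formula gives): rotating the x-axis by 90 degrees about z should produce the y-axis.

import jax.numpy as jnp

R = rotation_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.pi / 2)
print(R @ jnp.array([1.0, 0.0, 0.0]))  # ~[0., 1., 0.]

# The same rotation lifted to a Pose with zero translation.
pose = from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.pi / 2)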
89 changes: 89 additions & 0 deletions b3d/renderer.py
@@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes):
        return image[0], zs[0]


    def render_attribute_normal_many(self, poses, vertices, faces, ranges, attributes):
        """
        Render many scenes to an image by rasterizing and then interpolating attributes.

        Parameters:
            poses: float array, shape (num_scenes, num_objects, 4, 4)
                Object pose matrices.
            vertices: float array, shape (num_vertices, 3)
                Vertex position matrix.
            faces: int array, shape (num_triangles, 3)
                Triangle matrix. The integers correspond to rows in the vertices matrix.
            ranges: int array, shape (num_objects, 2)
                Ranges matrix whose 2 elements specify start indices and counts into faces.
            attributes: float array, shape (num_vertices, num_attributes)
                Attributes corresponding to the vertices.

        Outputs:
            image: float array, shape (num_scenes, height, width, num_attributes)
                At each pixel the value is the barycentric interpolation of the attributes corresponding to the
                3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
                any triangle the value at that pixel will be 0s.
            zs: float array, shape (num_scenes, height, width)
                Depth of the intersection point. Zero if the pixel's ray doesn't intersect a triangle.
            norm_im: float array, shape (num_scenes, height, width, 3)
                Approximate (flat, per-face) surface normal image.
        """
        uvs, object_ids, triangle_ids, zs = self.rasterize_many(
            poses, vertices, faces, ranges
        )
        mask = object_ids > 0

        interpolated_values = self.interpolate_many(
            attributes, uvs, triangle_ids, faces
        )
        image = interpolated_values * mask[..., None]

        # Transform each scene's triangle vertices by its poses.
        def apply_pose(pose, points):
            return pose.apply(points)

        pose_apply_map = jax.vmap(apply_pose, (0, None))
        new_vertices = pose_apply_map(poses, vertices[faces])

        # Flat per-triangle unit normal from the cross product of two edges.
        def normal_vec(x, y, z):
            vec = jnp.cross(y - x, z - x)
            norm_vec = vec / jnp.linalg.norm(vec)
            return norm_vec

        normal_vec_vmap = jax.vmap(jax.vmap(normal_vec, (0, 0, 0)))
        nvecs = normal_vec_vmap(new_vertices[..., 0, :], new_vertices[..., 1, :], new_vertices[..., 2, :])
        # Prepend a zero normal so background pixels (triangle id 0) map to 0s.
        norm_vecs = jnp.concatenate((jnp.zeros((len(nvecs), 1, 3)), nvecs), axis=1)

        # Look up each pixel's face normal by its triangle id.
        def indexer(transformed_normals, triangle_ids):
            return transformed_normals[triangle_ids]

        index_map = jax.vmap(indexer, (0, 0))
        norm_im = index_map(norm_vecs, triangle_ids)

        return image, zs, norm_im

    def render_attribute_normal(self, pose, vertices, faces, ranges, attributes):
        """
        Render a single scene to an image by rasterizing and then interpolating attributes.

        Parameters:
            poses: float array, shape (num_objects, 4, 4)
                Object pose matrices.
            vertices: float array, shape (num_vertices, 3)
                Vertex position matrix.
            faces: int array, shape (num_triangles, 3)
                Triangle matrix. The integers correspond to rows in the vertices matrix.
            ranges: int array, shape (num_objects, 2)
                Ranges matrix whose 2 elements specify start indices and counts into faces.
            attributes: float array, shape (num_vertices, num_attributes)
                Attributes corresponding to the vertices.

        Outputs:
            image: float array, shape (height, width, num_attributes)
                At each pixel the value is the barycentric interpolation of the attributes corresponding to the
                3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
                any triangle the value at that pixel will be 0s.
            zs: float array, shape (height, width)
                Depth of the intersection point. Zero if the pixel's ray doesn't intersect a triangle.
            norm_im: float array, shape (height, width, 3)
                Approximate (flat, per-face) surface normal image.
        """
        image, zs, norm_im = self.render_attribute_normal_many(
            pose[None, ...], vertices, faces, ranges, attributes
        )
        return image[0], zs[0], norm_im[0]


# XLA array layout in memory
def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]
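Review note: the normal image is flat-shaded: one normal per triangle, computed from the cross product of two edges, with a zero row prepended so background pixels (triangle id 0) look up a zero normal. A minimal single-scene sketch of the core computation, without the pose transform or vmap batching used above:

import jax.numpy as jnp

def face_normals(vertices, faces):
    tri = vertices[faces]                      # (num_triangles, 3, 3)
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    vec = jnp.cross(b - a, c - a)              # un-normalized face normals
    return vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)

def normal_image(vertices, faces, triangle_ids):
    # Index 0 is reserved for "no triangle", so prepend a zero normal.
    normals = jnp.concatenate([jnp.zeros((1, 3)), face_normals(vertices, faces)])
    return normals[triangle_ids]               # (height, width, 3)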
16 changes: 16 additions & 0 deletions b3d/utils.py
@@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
    enumerate_choices_get_scores, static_argnums=(2,)
)

def unproject_depth(depth, renderer):
    """Unprojects a depth image into a point cloud.

    Args:
        depth (jnp.ndarray): The depth image. Shape (H, W)
        renderer: The renderer, which carries the camera intrinsics
            (fx, fy, cx, cy, near, far).
    Returns:
        jnp.ndarray: The point cloud. Shape (H, W, 3)
    """
    # Clamp out-of-range depths to the far plane.
    mask = (depth < renderer.far) * (depth > renderer.near)
    depth = depth * mask + renderer.far * (1.0 - mask)
    # Pinhole model: X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy.
    y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
    x = (x - renderer.cx) / renderer.fx
    y = (y - renderer.cy) / renderer.fy
    point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None]
    return point_cloud_image

def nn_background_segmentation(images):
    import torch
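Review note: a sanity-check sketch for unproject_depth. The function itself only uses fx, fy, cx, cy, near, and far from the renderer; the height and width attributes below are assumptions used to build a test image, and the depth 2.0 is assumed to lie inside (near, far). A pixel at the principal point with depth d should unproject to (0, 0, d).

import jax.numpy as jnp

depth = jnp.full((renderer.height, renderer.width), 2.0)  # height/width assumed
cloud = unproject_depth(depth, renderer)
print(cloud[int(renderer.cy), int(renderer.cx)])  # ~[0., 0., 2.]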