diff --git a/b3d/enumerative_proposals.py b/b3d/enumerative_proposals.py index b6dc3da1..69e2e816 100644 --- a/b3d/enumerative_proposals.py +++ b/b3d/enumerative_proposals.py @@ -22,6 +22,46 @@ def _enumerate_and_select_best_move(trace, addressses, key, all_deltas): _enumerate_and_select_best_move, static_argnames=["addressses"] ) +def _enumerate_and_return_scores(trace, addressses, key, all_deltas): + addr = addressses.const[0] + current_pose = trace[addr] + for i in range(len(all_deltas)): + test_poses = current_pose @ all_deltas[i] + potential_scores = b3d.enumerate_choices_get_scores( + trace, jax.random.PRNGKey(0), addressses, test_poses + ) + return test_poses, potential_scores + + +enumerate_and_return_scores = jax.jit( + _enumerate_and_return_scores, static_argnames=["addressses"] +) + +def _enumerate_and_sample(trace, addressses, key, all_deltas): + addr = addressses.const[0] + test_poses = trace[addr] @ all_deltas + test_poses_batches = test_poses.split(10) + scores = jnp.concatenate( + [ + b3d.enumerate_choices_get_scores( + trace, key, genjax.Pytree.const(addr), poses + ) + for poses in test_poses_batches + ] + ) + trace = b3d.update_choices( + trace, + jax.random.PRNGKey(0), + genjax.Pytree.const(addr), + test_poses[scores.argmax()], + ) + key = jax.random.split(key, 2)[-1] + return trace, key + + +enumerate_and_sample = jax.jit( + _enumerate_and_sample, static_argnames=["addressses"] +) def _gvmf_and_select_best_move(trace, key, variance, concentration, address, number): addr = address.const diff --git a/b3d/model.py b/b3d/model.py index 54a888aa..da6cd7aa 100644 --- a/b3d/model.py +++ b/b3d/model.py @@ -67,6 +67,7 @@ def get_rgb_depth_inliers_from_trace(trace): return get_rgb_depth_inliers_from_observed_rendered_args(observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args) def get_rgb_depth_inliers_from_observed_rendered_args(observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args): + # conversion from lightness-based color model to intrinsic image-based color model observed_lab = b3d.rgb_to_lab(observed_rgb) rendered_lab = b3d.rgb_to_lab(rendered_rgb) error = ( @@ -100,6 +101,7 @@ def logpdf(self, observed, rendered_rgb, rendered_depth, model_args, fx, fy, far corrected_depth = rendered_depth + (rendered_depth == 0.0) * far areas = (corrected_depth / fx) * (corrected_depth / fy) + #areas = jnp.ones(areas.shape) return jnp.log( # This is leaving out a 1/A (which does depend upon the scene) @@ -110,6 +112,39 @@ def logpdf(self, observed, rendered_rgb, rendered_depth, model_args, fx, fy, far rgbd_sensor_model = RGBDSensorModel() +class RGBDSensorModelNorm(ExactDensity,genjax.JAXGenerativeFunction): + def sample(self, key, rendered_rgb, rendered_depth, norm_im, model_args, fx, fy, far): + return (rendered_rgb, rendered_depth) + + def logpdf(self, observed, rendered_rgb, rendered_depth, norm_im, model_args, fx, fy, far): + observed_rgb, observed_depth = observed + + inliers, color_inliers, depth_inliers, outliers, undecided, valid_data_mask = get_rgb_depth_inliers_from_observed_rendered_args( + observed_rgb, rendered_rgb, observed_depth, rendered_depth, model_args + ) + + inlier_score = model_args.inlier_score + outlier_prob = model_args.outlier_prob + multiplier = model_args.color_multiplier + + corrected_depth = rendered_depth + (rendered_depth == 0.0) * far + + flat_cos = jnp.abs(norm_im @ jnp.array([0,0,1])) + # adding in a pseudo back plane + inv_clip_cos = jnp.multiply(1/jnp.clip(flat_cos, 0.01, 1), (rendered_depth != 0.0) * 1) + (rendered_depth == 0.0) * 1.0 + depth_corr = jnp.multiply(inv_clip_cos, corrected_depth) + areas_flat = ((depth_corr / fx) * (depth_corr / fy)) + + return jnp.log( + # This is leaving out a 1/A (which does depend upon the scene) + inlier_score * jnp.sum(inliers * areas_flat) + + 1.0 * jnp.sum(undecided * areas_flat) + + outlier_prob * jnp.sum(outliers * areas_flat) + ) * multiplier + +rgbd_sensor_model_surfacenorm = RGBDSensorModelNorm() + + def model_multiobject_gl_factory(renderer, image_likelihood=rgbd_sensor_model): @genjax.static_gen_fn def model( @@ -145,6 +180,42 @@ def model( return (observed_rgb, rendered_rgb), (observed_depth, rendered_depth) return model +def model_multiobject_gl_factory_normal(renderer, image_likelihood=rgbd_sensor_model): + @genjax.static_gen_fn + def model( + _num_obj_arr, # new + model_args, + object_library + ): + + object_poses = Pose(jnp.zeros((0,3)), jnp.zeros((0,4))) + object_indices = jnp.empty((0,), dtype=int) + camera_pose = uniform_pose(jnp.ones(3)*-100.0, jnp.ones(3)*100.0) @ f"camera_pose" + + for i in range(_num_obj_arr.shape[0]): + object_identity = uniform_discrete(jnp.arange(-1, len(object_library.ranges))) @ f"object_{i}" + object_indices = jnp.concatenate((object_indices, jnp.array([object_identity]))) + + object_pose = uniform_pose(jnp.ones(3)*-100.0, jnp.ones(3)*100.0) @ f"object_pose_{i}" + object_poses = Pose.concatenate_poses([object_poses, camera_pose.inv() @ object_pose[None,...]]) + + rendered_rgb, rendered_depth, rendered_norm_im = renderer.render_attribute_normal( + object_poses, + object_library.vertices, + object_library.faces, + object_library.ranges[object_indices] * (object_indices >= 0).reshape(-1,1), + object_library.attributes + ) + observed_rgb, observed_depth = image_likelihood( + rendered_rgb, rendered_depth, + rendered_norm_im, + model_args, + renderer.fx, renderer.fy, + 1.0 + ) @ "observed_rgb_depth" + return (observed_rgb, rendered_rgb), (observed_depth, rendered_depth) + return model + def get_rendered_rgb_depth_from_trace(trace): (observed_rgb, rendered_rgb), (observed_depth, rendered_depth) = trace.get_retval return (rendered_rgb, rendered_depth) @@ -200,4 +271,4 @@ def rerun_visualize_trace_t(trace, t, modes=["rgb", "depth", "inliers"]): pose.apply(vertices), colors=(attributes * 255).astype(jnp.uint8), ), - ) \ No newline at end of file + ) diff --git a/b3d/pose.py b/b3d/pose.py index c736cf89..1388bc88 100644 --- a/b3d/pose.py +++ b/b3d/pose.py @@ -90,6 +90,53 @@ def camera_from_position_and_target( rotation_matrix = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)]) return Pose(position, Rot.from_matrix(rotation_matrix).as_quat()) +def rotation_from_axis_angle(axis, angle): + """Creates a rotation matrix from an axis and angle. + + Args: + axis (jnp.ndarray): The axis vector. Shape (3,) + angle (float): The angle in radians. + Returns: + jnp.ndarray: The rotation matrix. Shape (3, 3) + """ + sina = jnp.sin(angle) + cosa = jnp.cos(angle) + direction = axis / jnp.linalg.norm(axis) + # rotation matrix around unit vector + R = jnp.diag(jnp.array([cosa, cosa, cosa])) + R = R + jnp.outer(direction, direction) * (1.0 - cosa) + direction = direction * sina + R = R + jnp.array( + [ + [0.0, -direction[2], direction[1]], + [direction[2], 0.0, -direction[0]], + [-direction[1], direction[0], 0.0], + ] + ) + return R + +def from_rot(rotation): + """Creates a pose matrix from a rotation matrix. + + Args: + rotation (jnp.ndarray): The rotation matrix. Shape (3, 3) + Returns: + Pose object + """ + return Pose.from_matrix(jnp.vstack( + [jnp.hstack([rotation, jnp.zeros((3, 1))]), jnp.array([0.0, 0.0, 0.0, 1.0])] + )) + +def from_axis_angle(axis, angle): + """Creates a pose matrix from an axis and angle. + + Args: + axis (jnp.ndarray): The axis vector. Shape (3,) + angle (float): The angle in radians. + Returns: + Pose object + """ + return from_rot(rotation_from_axis_angle(axis, angle)) @register_pytree_node_class class Pose: diff --git a/b3d/renderer.py b/b3d/renderer.py index d60eab02..d1a2013a 100644 --- a/b3d/renderer.py +++ b/b3d/renderer.py @@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes): return image[0], zs[0] + def render_attribute_normal_many(self, poses, vertices, faces, ranges, attributes): + """ + Render many scenes to an image by rasterizing and then interpolating attributes. + + Parameters: + poses: float array, shape (num_scenes, num_objectsß, 4, 4) + Object pose matrix. + vertices: float array, shape (num_vertices, 3) + Vertex position matrix. + faces: int array, shape (num_triangles, 3) + Faces Triangle matrix. The integers ßcorrespond to rows in the vertices matrix. + ranges: int array, shape (num_objects, 2) + Ranges matrix with the 2 elements specify start indices and counts into faces. + attributes: float array, shape (num_vertices, num_attributes) + Attributes corresponding to the vertices + + Outputs: + image: float array, shape (num_scenes, height, width, num_attributes) + At each pixel the value is the barycentric interpolation of the attributes corresponding to the + 3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect + any triangle the value at that pixel will be 0s. + zs: float array, shape (num_scenes, height, width) + Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle. + norm_im: approximate surface normal image (num_scenes, height, width, 3) + """ + uvs, object_ids, triangle_ids, zs = self.rasterize_many( + poses, vertices, faces, ranges + ) + mask = object_ids > 0 + + interpolated_values = self.interpolate_many( + attributes, uvs, triangle_ids, faces + ) + image = interpolated_values * mask[..., None] + + def apply_pose(pose, points): + return pose.apply(points) + + pose_apply_map = jax.vmap(apply_pose, (0,None)) + new_vertices = pose_apply_map(poses, vertices[faces]) + + def normal_vec(x,y,z): + vec = jnp.cross(y - x, z - x) + norm_vec = vec / jnp.linalg.norm(vec) + return norm_vec + + normal_vec_vmap = jax.vmap(jax.vmap(normal_vec, (0,0,0))) + nvecs = normal_vec_vmap(new_vertices[...,0,:], new_vertices[...,1,:], new_vertices[...,2,:]) + norm_vecs = jnp.concatenate((jnp.zeros((len(nvecs),1,3)), nvecs),axis=1) + + def indexer(transformed_normals, triangle_ids): + return transformed_normals[triangle_ids] + + index_map = jax.vmap(indexer, (0,0)) + norm_im = index_map(norm_vecs, triangle_ids) + + return image, zs, norm_im + + def render_attribute_normal(self, pose, vertices, faces, ranges, attributes): + """ + Render a single scenes to an image by rasterizing and then interpolating attributes. + + Parameters: + poses: float array, shape (num_objects, 4, 4) + Object pose matrix. + vertices: float array, shape (num_vertices, 3) + Vertex position matrix. + faces: int array, shape (num_triangles, 3) + Faces Triangle matrix. The integers correspond to rows in the vertices matrix. + ranges: int array, shape (num_objects, 2) + Ranges matrix with the 2 elements specify start indices and counts into faces. + attributes: float array, shape (num_vertices, num_attributes) + Attributes corresponding to the vertices + + Outputs: + image: float array, shape (height, width, num_attributes) + At each pixel the value is the barycentric interpolation of the attributes corresponding to the + 3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect + any triangle the value at that pixel will be 0s. + zs: float array, shape (height, width) + Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle. + norm_im: approximate surface normal image (height, width, 3) + """ + image, zs, norm_im = self.render_attribute_normal_many( + pose[None, ...], vertices, faces, ranges, attributes + ) + return image[0], zs[0], norm_im[0] + + # XLA array layout in memory def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] diff --git a/b3d/utils.py b/b3d/utils.py index 4bc2c950..9200fa6b 100644 --- a/b3d/utils.py +++ b/b3d/utils.py @@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values): enumerate_choices_get_scores, static_argnums=(2,) ) +def unproject_depth(depth, renderer): + """Unprojects a depth image into a point cloud. + + Args: + depth (jnp.ndarray): The depth image. Shape (H, W) + intrinsics (b.camera.Intrinsics): The camera intrinsics. + Returns: + jnp.ndarray: The point cloud. Shape (H, W, 3) + """ + mask = (depth < renderer.far) * (depth > renderer.near) + depth = depth * mask + renderer.far * (1.0 - mask) + y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]] + x = (x - renderer.cx) / renderer.fx + y = (y - renderer.cy) / renderer.fy + point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None] + return point_cloud_image def nn_background_segmentation(images): import torch diff --git a/test/test_likelihood_invariances.py b/test/test_likelihood_invariances.py index 3a3b2933..dff1679c 100644 --- a/test/test_likelihood_invariances.py +++ b/test/test_likelihood_invariances.py @@ -169,3 +169,96 @@ def test_distance_to_camera_invarance(renderer): assert jnp.isclose(near_score, far_score, rtol=0.03) +def test_patch_orientation_invariance(renderer): + + object_library = b3d.MeshLibrary.make_empty_library() + occluder = trimesh.creation.box(extents=jnp.array([0.0001, 0.1, 0.1])) + occluder_colors = jnp.tile(jnp.array([0.8, 0.8, 0.8])[None,...], (occluder.vertices.shape[0], 1)) + object_library = b3d.MeshLibrary.make_empty_library() + object_library.add_object(occluder.vertices, occluder.faces, attributes=occluder_colors) + + image_width = 200 + image_height = 200 + fx = 200.0 + fy = 200.0 + cx = 100.0 + cy = 100.0 + near = 0.001 + far = 16.0 + renderer.set_intrinsics(image_width, image_height, fx, fy, cx, cy, near, far) + + flat_pose = b3d.Pose.from_position_and_target( + jnp.array([0.3, 0.0, 0.0]), jnp.array([0.0, 0.0, 0.0]), jnp.array([0.0, 0.0, 1.0]) + ).inv() + + from b3d.pose import from_axis_angle + + transform_vec = jax.vmap(from_axis_angle, (None, 0)) + in_place_rots = transform_vec(jnp.array([0,0,1]), jnp.linspace(0, jnp.pi/4, 10)) + tilt_pose = flat_pose @ in_place_rots[5] + + rgb_flat, depth_flat = renderer.render_attribute( + flat_pose[None, ...], + object_library.vertices, + object_library.faces, + object_library.ranges, + object_library.attributes, + ) + + rgb_tilt, depth_tilt = renderer.render_attribute( + tilt_pose[None, ...], + object_library.vertices, + object_library.faces, + object_library.ranges, + object_library.attributes, + ) + + + color_error, depth_error = (50.0, 0.01) + inlier_score, outlier_prob = (4.0, 0.000001) + color_multiplier, depth_multiplier = (100.0, 1.0) + model_args = b3d.ModelArgs( + color_error, + depth_error, + inlier_score, + outlier_prob, + color_multiplier, + depth_multiplier, + ) + + from genjax.generative_functions.distributions import ExactDensity + import genjax + + + rr.log("img_near", rr.Image(rgb_flat)) + rr.log("img_far", rr.Image(rgb_tilt)) + + + + area_flat = ((depth_flat / fx) * (depth_flat / fy)).sum() + area_tilt = ((depth_tilt / fx) * (depth_tilt / fy)).sum() + print(area_flat, area_tilt) + + flat_score = ( + b3d.rgbd_sensor_model.logpdf( + (rgb_flat, depth_flat), rgb_flat, depth_flat, model_args, fx, fy, 0.0 + ) + ) + + tilt_score = ( + b3d.rgbd_sensor_model.logpdf( + (rgb_tilt, depth_tilt), rgb_tilt, depth_tilt, model_args, fx, fy, 0.0 + ) + ) + print(flat_score, tilt_score) + print(b3d.normalize_log_scores(jnp.array([flat_score, tilt_score]))) + + assert jnp.isclose(flat_score, tilt_score, rtol=0.05) + + +def test_patch_posterior_samples(renderer): + sum = 0 + + + + assert sum >= 0 \ No newline at end of file diff --git a/test/test_render_ycb_model.py b/test/test_render_ycb_model.py index e1f043cf..4c4a50a1 100644 --- a/test/test_render_ycb_model.py +++ b/test/test_render_ycb_model.py @@ -2,7 +2,11 @@ import jax.numpy as jnp import trimesh import b3d +import rerun as rr +PORT = 8812 +rr.init("real") +rr.connect(addr=f"127.0.0.1:{PORT}") def test_renderer_full(renderer): mesh_path = os.path.join( @@ -15,7 +19,7 @@ def test_renderer_full(renderer): object_library.add_trimesh(mesh) pose = b3d.Pose.from_position_and_target( - jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0]) + jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0]) ).inv() rgb, depth = renderer.render_attribute( @@ -27,3 +31,33 @@ def test_renderer_full(renderer): ) b3d.get_rgb_pil_image(rgb).save(b3d.get_root_path() / "assets/test_results/test_ycb.png") assert rgb.sum() > 0 + +def test_renderer_normal_full(renderer): + mesh_path = os.path.join( + b3d.get_root_path(), + "assets/shared_data_bucket/ycb_video_models/models/003_cracker_box/textured_simple.obj", + ) + mesh = trimesh.load(mesh_path) + + object_library = b3d.MeshLibrary.make_empty_library() + object_library.add_trimesh(mesh) + + pose = b3d.Pose.from_position_and_target( + jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0]) + ).inv() + + rgb, depth, normal = renderer.render_attribute_normal( + pose[None, ...], + object_library.vertices, + object_library.faces, + jnp.array([[0, len(object_library.faces)]]), + object_library.attributes, + ) + + b3d.get_rgb_pil_image((normal+1)/2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png") + + point_im = b3d.utils.unproject_depth(depth, renderer) + rr.log("pc", rr.Points3D(point_im.reshape(-1,3), colors=rgb.reshape(-1,3))) + rr.log("arrows", rr.Arrows3D(origins=point_im[::5,::5,:].reshape(-1,3), vectors=normal[::5,::5,:].reshape(-1,3)/100)) + + assert jnp.abs(normal).sum() > 0