human_mesh_recovery.py
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# ROS wrapper of Human Mesh Recovery (HMR)
#    See: End-to-end Recovery of Human Shape and Pose
#         https://akanazawa.github.io/hmr/

import chainer
import chainer.functions as F
from chainer import Variable
import cv2
import numpy as np
import pylab as plt  # NOQA

import tf
import cv_bridge
import message_filters
import rospy
from jsk_topic_tools import ConnectionBasedTransport
from geometry_msgs.msg import Pose
from geometry_msgs.msg import Point
from geometry_msgs.msg import Quaternion
from jsk_recognition_msgs.msg import PeoplePose
from jsk_recognition_msgs.msg import PeoplePoseArray
from sensor_msgs.msg import Image

from hmr.smpl import SMPL
from hmr.net import EncoderFC3Dropout
from hmr.resnet_v2_50 import ResNet_v2_50


def format_pose_msg(person_pose):
    """Flatten a PeoplePose message into an array of (x, y, score) rows."""
    key_points = []
    for pose, score in zip(person_pose.poses, person_pose.scores):
        key_points.append(pose.position.x)
        key_points.append(pose.position.y)
        key_points.append(score)
    return np.array(key_points, 'f').reshape(-1, 3)


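# Illustrative check of format_pose_msg (values are made up, not from a real
# detector):
#
#   from geometry_msgs.msg import Pose, Point
#   person_pose = PeoplePose(
#       poses=[Pose(position=Point(x=120., y=80., z=0.))],
#       scores=[0.9])
#   format_pose_msg(person_pose)
#   # -> array([[120.,  80.,   0.9]], dtype=float32)

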
def resize_img(img, scale_factor):
    new_size = (np.floor(np.array(img.shape[0:2]) * scale_factor)).astype(int)
    new_img = cv2.resize(img, (new_size[1], new_size[0]))
    # Actual scale factor of [height, width], i.e. [y, x]
    actual_factor = [
        new_size[0] / float(img.shape[0]), new_size[1] / float(img.shape[1])
    ]
    return new_img, actual_factor


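# Illustrative usage of resize_img (a minimal sketch with a synthetic image):
#
#   img = np.zeros((480, 640, 3), np.uint8)
#   small, factor = resize_img(img, 0.5)
#   small.shape   # (240, 320, 3)
#   factor        # [0.5, 0.5], the [y, x] scale actually achieved

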
def scale_and_crop(image, scale, center, img_size):
    image_scaled, scale_factors = resize_img(image, scale)
    # Swap so it's [x, y]
    scale_factors = [scale_factors[1], scale_factors[0]]
    # np.int is a deprecated alias of the builtin int; use int directly
    center_scaled = np.round(center * scale_factors).astype(int)

    margin = int(img_size / 2)
    image_pad = np.pad(
        image_scaled, ((margin, ), (margin, ), (0, )), mode='edge')
    center_pad = center_scaled + margin
    # figure out starting point
    start_pt = center_pad - margin
    end_pt = center_pad + margin
    # crop:
    crop = image_pad[start_pt[1]:end_pt[1], start_pt[0]:end_pt[0], :]
    proc_param = {
        'scale': scale,
        'start_pt': start_pt,
        'end_pt': end_pt,
        'img_size': img_size
    }

    return crop, proc_param


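# Illustrative usage of scale_and_crop (synthetic input; not part of the node):
#
#   img = np.zeros((480, 640, 3), np.uint8)
#   crop, proc_param = scale_and_crop(
#       img, scale=1., center=np.array([320, 240]), img_size=224)
#   crop.shape   # (224, 224, 3): a 224x224 window centered on (x=320, y=240)

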
def get_bbox(key_points, vis_thr=0.2):
    # Keep only the key points whose confidence exceeds the threshold.
    vis = key_points[:, 2] > vis_thr
    vis_kp = key_points[vis, :2]
    if len(vis_kp) == 0:
        return False, False
    min_pt = np.min(vis_kp, axis=0)
    max_pt = np.max(vis_kp, axis=0)
    person_height = np.linalg.norm(max_pt - min_pt)
    if person_height == 0:
        return False, False
    center = (min_pt + max_pt) / 2.
    # Scale so the person spans roughly 150 px in the crop.
    scale = 150. / person_height
    return scale, center


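# Illustrative usage of get_bbox (made-up (x, y, score) rows):
#
#   key_points = np.array([[100., 100., 0.9],
#                          [100., 400., 0.8],
#                          [200., 250., 0.1]], 'f')  # last row below vis_thr
#   scale, center = get_bbox(key_points)
#   # person_height = ||(100, 400) - (100, 100)|| = 300, so scale = 0.5
#   # center = [100., 250.]

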
def preprocess_image(img, key_points=None, img_size=224):
    if key_points is None:
        scale = 1.
        center = np.round(np.array(img.shape[:2]) / 2).astype(int)
        # image center in (x, y)
        center = center[::-1]
    else:
        scale, center = get_bbox(key_points, vis_thr=0.1)
        if scale is False:
            # No reliable key points: fall back to the image center.
            scale = 1.
            center = np.round(np.array(img.shape[:2]) / 2).astype(int)
            # image center in (x, y)
            center = center[::-1]
    crop_img, proc_param = scale_and_crop(img, scale, center, img_size)

    # Normalize image to [-1, 1]
    crop_img = 2 * ((crop_img / 255.) - 0.5)
    return crop_img, proc_param


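# Illustrative usage of preprocess_image (synthetic image; shapes only):
#
#   img = np.full((480, 640, 3), 255, np.uint8)
#   crop, proc_param = preprocess_image(img)
#   crop.shape   # (224, 224, 3), values normalized to [-1, 1]
#   crop.max()   # 1.0 for an all-white input

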
# Mean parameter vector theta used to initialize the regressor: 85 values,
# laid out as [camera (3), SMPL pose (72 = 24 joints x 3 axis-angle),
# shape (10)], matching the slicing in pose_estimate below.
mean = np.array([[
    0.90365213, -0.00383353,  0.03301106,  3.14986515, -0.01883755,
    0.16895422, -0.15615709, -0.0058559,  0.07191881, -0.18924442,
    -0.04396844, -0.05114707,  0.24385466,  0.00881136,  0.02384637,
    0.2066803, -0.10190887, -0.03373535,  0.27340922,  0.00637481,
    0.07408072, -0.03409823, -0.00971786,  0.03841642,  0.0191336,
    0.10812955, -0.06782207, -0.08026548, -0.18373352,  0.16556455,
    0.03735897, -0.02497507,  0.02688527, -0.18802814,  0.17772846,
    0.13135587,  0.01438429,  0.15891947, -0.2279436, -0.07497088,
    0.05169746,  0.08784129,  0.02147929,  0.02320284, -0.42375749,
    -0.04963749,  0.08627309,  0.47963148,  0.26528436, -0.1028522,
    -0.02501041,  0.05762934, -0.26270828, -0.8297376,  0.13903582,
    0.30412629,  0.79824799,  0.12842464, -0.64139324,  0.16469972,
    -0.08669609,  0.55955994, -0.16742738, -0.03153928, -0.02245264,
    -0.02357809,  0.02061746,  0.02320515,  0.00869796, -0.1655257,
    -0.07094092, -0.1663706, -0.10953037,  0.11675739,  0.20495811,
    0.10592803,  0.14583197, -0.31755996,  0.13645983,  0.28833047,
    0.06303538,  0.48629287,  0.23359743, -0.02812484,  0.23948504]], 'f')


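# How the 85-D theta splits (same slicing as in pose_estimate):
#
#   cams = mean[:, :3]      # weak-perspective camera (scale and 2-D
#                           # translation, per the HMR paper)
#   poses = mean[:, 3:75]   # 72 axis-angle SMPL pose parameters
#   shapes = mean[:, 75:]   # 10 SMPL shape coefficients

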
class HumanMeshRecovery(ConnectionBasedTransport):

    def __init__(self):
        # super(self.__class__, ...) breaks under further subclassing;
        # name the class explicitly.
        super(HumanMeshRecovery, self).__init__()
        self.gpu = rospy.get_param('~gpu', -1)  # -1 is cpu mode
        self.with_people_pose = rospy.get_param('~with_people_pose', False)
        self.num_stage = rospy.get_param('~num_stage', 3)

        self.smpl = SMPL()
        self.encoder_fc3_model = EncoderFC3Dropout()
        self.resnet_v2 = ResNet_v2_50()
        self._load_model()

        self.br = cv_bridge.CvBridge()
        self.pose_pub = self.advertise(
            '~output/pose', PeoplePoseArray, queue_size=1)

    def _load_model(self):
        smpl_model_file = rospy.get_param('~smpl_model_file')
        chainer.serializers.load_npz(smpl_model_file, self.smpl)
        encoder_fc3_model_file = rospy.get_param('~encoder_model_file')
        chainer.serializers.load_npz(
            encoder_fc3_model_file, self.encoder_fc3_model)
        resnet_v2_50_model_file = rospy.get_param('~resnet_v2_50_model_file')
        chainer.serializers.load_npz(resnet_v2_50_model_file, self.resnet_v2)

        rospy.loginfo('Finished loading trained model')
        if self.gpu >= 0:
            chainer.cuda.get_device_from_id(self.gpu).use()
            self.smpl.to_gpu()
            self.encoder_fc3_model.to_gpu()
            self.resnet_v2.to_gpu()
        chainer.global_config.train = False
        chainer.global_config.enable_backprop = False

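    # Illustrative launch of this node (the model paths are placeholders, not
    # shipped defaults; the package name follows the jsk_perception sources):
    #
    #   rosrun jsk_perception human_mesh_recovery.py \
    #       _smpl_model_file:=/path/to/smpl.npz \
    #       _encoder_model_file:=/path/to/encoder_fc3.npz \
    #       _resnet_v2_50_model_file:=/path/to/resnet_v2_50.npz \
    #       _gpu:=0
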
    def subscribe(self):
        if self.with_people_pose:
            queue_size = rospy.get_param('~queue_size', 10)
            sub_img = message_filters.Subscriber(
                '~input', Image, queue_size=queue_size, buff_size=2**24)
            sub_pose = message_filters.Subscriber(
                '~input/pose', PeoplePoseArray,
                queue_size=queue_size, buff_size=2**24)
            self.subs = [sub_img, sub_pose]

            if rospy.get_param('~approximate_sync', False):
                slop = rospy.get_param('~slop', 0.1)
                sync = message_filters.ApproximateTimeSynchronizer(
                    fs=self.subs, queue_size=queue_size, slop=slop)
            else:
                sync = message_filters.TimeSynchronizer(
                    fs=self.subs, queue_size=queue_size)
            sync.registerCallback(self._cb_with_pose)
        else:
            sub_img = rospy.Subscriber(
                '~input', Image, self._cb,
                queue_size=1, buff_size=2**24)
            self.subs = [sub_img]

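    # Illustrative remapping when feeding 2-D poses in (the source topics on
    # the right-hand side are examples, not fixed defaults):
    #
    #   rosrun jsk_perception human_mesh_recovery.py \
    #       ~input:=/camera/rgb/image_rect_color \
    #       ~input/pose:=/people_pose_estimation_2d/pose \
    #       _with_people_pose:=true _approximate_sync:=true
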
    def unsubscribe(self):
        for sub in self.subs:
            sub.unregister()

    def _cb(self, img_msg):
        br = self.br
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        img, _ = preprocess_image(img)
        imgs = img.transpose(2, 0, 1)[None]  # HWC -> 1 x C x H x W
        verts, Js, Rs, A, cams, poses, shapes = self.pose_estimate(imgs)

        people_pose_msg = self._create_people_pose_array_msgs(
            chainer.cuda.to_cpu(A.data), img_msg.header)
        self.pose_pub.publish(people_pose_msg)

    def _cb_with_pose(self, img_msg, people_pose_msg):
        br = self.br
        img = br.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')

        imgs = []
        for person_pose in people_pose_msg.poses:
            key_points = format_pose_msg(person_pose)
            crop_img, _ = preprocess_image(img, key_points)
            imgs.append(crop_img)
        if len(imgs) == 0:
            # No people detected: fall back to a single center crop.
            img, _ = preprocess_image(img)
            imgs = np.array(img[None], 'f').transpose(0, 3, 1, 2)
        else:
            imgs = np.array(imgs, 'f').transpose(0, 3, 1, 2)
        verts, Js, Rs, A, cams, poses, shapes = self.pose_estimate(imgs)

        people_pose_msg = self._create_people_pose_array_msgs(
            chainer.cuda.to_cpu(A.data), img_msg.header)
        self.pose_pub.publish(people_pose_msg)

    def _create_people_pose_array_msgs(self, people_joint_positions, header):
        people_pose_msg = PeoplePoseArray(header=header)
        for i, person_joint_positions in enumerate(people_joint_positions):
            pose_msg = PeoplePose()
            for joint_pose in person_joint_positions:
                # joint_pose is a 4x4 homogeneous transform of one SMPL joint
                pose_msg.limb_names.append(str(i))
                pose_msg.scores.append(0.0)
                q_xyzw = tf.transformations.quaternion_from_matrix(joint_pose)
                pose_msg.poses.append(
                    Pose(position=Point(
                        x=joint_pose[0, 3],
                        y=joint_pose[1, 3],
                        z=joint_pose[2, 3]),
                        orientation=Quaternion(
                            x=q_xyzw[0],
                            y=q_xyzw[1],
                            z=q_xyzw[2],
                            w=q_xyzw[3])))
            people_pose_msg.poses.append(pose_msg)
        return people_pose_msg

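    # Illustrative consumer of ~output/pose (a sketch; the resolved topic name
    # depends on how the node is launched):
    #
    #   def pose_cb(msg):
    #       for person in msg.poses:        # one PeoplePose per person
    #           for pose in person.poses:   # one Pose per SMPL joint
    #               print(pose.position.x, pose.position.y, pose.position.z)
    #
    #   rospy.Subscriber('human_mesh_recovery/output/pose',
    #                    PeoplePoseArray, pose_cb)
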
    def pose_estimate(self, imgs):
        batch_size = imgs.shape[0]
        imgs = Variable(self.resnet_v2.xp.array(imgs, 'f'))
        img_feat = self.resnet_v2(imgs).reshape(batch_size, -1)

        # Iterative error feedback: start from the mean theta and let the
        # regressor predict residual updates for num_stage iterations.
        theta_prev = F.tile(
            Variable(self.encoder_fc3_model.xp.array(mean, 'f')),
            (batch_size, 1))
        num_cam = 3
        num_theta = 72
        for i in range(self.num_stage):
            state = F.concat([img_feat, theta_prev], axis=1)
            delta_theta = self.encoder_fc3_model(state)
            theta_here = theta_prev + delta_theta
            # cam: N x 3, pose: N x num_theta, shape: N x 10
            cams = theta_here[:, :num_cam]
            poses = theta_here[:, num_cam:(num_cam + num_theta)]
            shapes = theta_here[:, (num_cam + num_theta):]

            verts, Js, Rs, A = self.smpl(shapes, poses)
            # Project to 2D!
            # pred_kp = batch_orth_proj_idrot(
            #     Js, cams, name='proj_2d_stage%d' % i)
            theta_prev = theta_here
        return verts, Js, Rs, A, cams, poses, shapes


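# Illustrative offline call to pose_estimate (a sketch; assumes a running ROS
# master, the model-file parameters set as above, and a hypothetical image):
#
#   node = HumanMeshRecovery()
#   img = cv2.imread('person.jpg')
#   crop, _ = preprocess_image(img)
#   imgs = np.array(crop[None], 'f').transpose(0, 3, 1, 2)  # 1 x 3 x 224 x 224
#   verts, Js, Rs, A, cams, poses, shapes = node.pose_estimate(imgs)
#   # A holds one 4x4 transform per SMPL joint for each person in the batch

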
if __name__ == '__main__':
    rospy.init_node('human_mesh_recovery')
    HumanMeshRecovery()
    rospy.spin()

