00001 from __future__ import absolute_import
00002 from __future__ import division
00003 from __future__ import print_function
00004
00005 from chainer import Parameter
00006 from chainer import Variable
00007 import chainer
00008 import chainer.functions as F
00009 import numpy as np
00010
00011
def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False):
    """Compute posed joint locations and per-joint skinning transforms.

    Walks the kinematic tree described by ``parent`` and composes, for every
    joint, the 4 x 4 homogeneous transform from its local frame to the world
    frame.

    Args:
        Rs: N x K x 3 x 3 rotation matrices, one per joint (K is 24 for SMPL).
        Js: N x K x 3 joint locations before posing.
        parent: length-K integer array holding the parent index of each
            joint.  Entry 0 (the root) is never read, and every joint must
            appear after its parent.
        rotate_base: if True, the root rotation is post-multiplied by
            diag(1, -1, -1), i.e. a 180 degree rotation about the x axis.
            If False, the original SMPL coordinate frame is kept.

    Returns:
        new_J: N x K x 3 locations of the posed joints.
        A: N x K x 4 x 4 joint transformations relative to the rest pose,
            for use in linear blend skinning.
    """
    xp = Rs.xp
    N = Rs.shape[0]
    if rotate_base:
        print('Flipping the SMPL coordinate frame!!!!')
        # Constant matrix: build it with the array module directly, since
        # ``Variable`` accepts neither a Python list nor a ``dtype`` keyword.
        flip = xp.array(
            [[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=Rs.dtype)
        flip = xp.tile(flip, (N, 1, 1))
        root_rotation = F.matmul(Rs[:, 0, :, :], flip)
    else:
        root_rotation = Rs[:, 0, :, :]

    # N x K x 3 x 1, so that each joint is a column vector.
    Js = F.expand_dims(Js, -1)
    num_joints = Js.shape[1]

    def make_A(R, t):
        """Assemble N x 4 x 4 homogeneous transforms from R and t."""
        # Append a zero row below R and a 1 below t, then join the columns.
        R_homo = F.pad(R, [[0, 0], [0, 1], [0, 0]], 'constant')
        t_homo = F.concat([t, xp.ones((N, 1, 1), dtype=t.dtype)], 1)
        return F.concat([R_homo, t_homo], 2)

    results = [make_A(root_rotation, Js[:, 0])]
    for i in range(1, parent.shape[0]):
        p = int(parent[i])
        # Offset of joint i expressed in its parent's rest frame.
        j_here = Js[:, i] - Js[:, p]
        A_here = make_A(Rs[:, i], j_here)
        results.append(F.matmul(results[p], A_here))

    # N x K x 4 x 4 world transforms.
    results = F.stack(results, axis=1)

    new_J = results[:, :, :3, 3]

    # Skinning depends on how far each bone moved, not on where it ended up:
    # subtract the transformed rest-pose joint from the translation column.
    Js_w0 = F.concat(
        [Js, xp.zeros((N, num_joints, 1, 1), dtype=Js.dtype)], 2)
    init_bone = F.matmul(results, Js_w0)
    # Left-pad with three zero columns so only the translation is affected.
    init_bone = F.pad(
        init_bone, [[0, 0], [0, 0], [0, 0], [3, 0]], 'constant')
    A = results - init_bone

    return new_J, A
00073
00074
def batch_skew(vec, batch_size=None):
    """Build the skew-symmetric (cross-product) matrix of each 3-vector.

    The result is assembled with differentiable operations, so gradients
    flow back to ``vec``.

    Args:
        vec: N x 3 vectors.  Any input holding ``batch_size * 3`` elements
            (for example N x 3 x 1) is accepted and flattened to N x 3.
        batch_size: N.  Defaults to ``vec.shape[0]``.

    Returns:
        N x 3 x 3, where entry ``i`` is
        ``[[0, -z, y], [z, 0, -x], [-y, x, 0]]`` for ``vec[i] = (x, y, z)``.
    """
    xp = vec.xp
    if batch_size is None:
        batch_size = vec.shape[0]

    vec = F.reshape(vec, (batch_size, 3))
    x, y, z = vec[:, 0], vec[:, 1], vec[:, 2]
    zero = xp.zeros((batch_size,), dtype=vec.dtype)

    # Row-major layout of the 3 x 3 matrix; stacking keeps the graph
    # connected, unlike writing into ``Variable.data`` in place.
    flat = F.stack([zero, -z, y,
                    z, zero, -x,
                    -y, x, zero], axis=1)
    return F.reshape(flat, (batch_size, 3, 3))
00097
00098
def batch_rodrigues(theta):
    """Convert axis-angle vectors to rotation matrices (Rodrigues' formula).

    Args:
        theta: N x 3 axis-angle vectors.

    Returns:
        N x 3 x 3 rotation matrices.
    """
    xp = theta.xp
    n_rot = theta.shape[0]

    # Rotation angle per vector, N x 1.  The small offset keeps the norm
    # away from zero before it is used as a divisor.
    norm = F.sqrt(F.batch_l2_norm_squared(theta + 1e-8))
    angle_col = F.expand_dims(norm, -1)

    # Unit rotation axis as a column vector, N x 3 x 1.
    axis = F.expand_dims(theta / F.tile(angle_col, 3), -1)

    # cos / sin of the angle, replicated across each 3 x 3 matrix.
    angle_mat = F.expand_dims(angle_col, -1)
    cos_term = F.tile(F.cos(angle_mat), (3, 3))
    sin_term = F.tile(F.sin(angle_mat), (3, 3))

    identity = Variable(xp.array(xp.eye(3), 'f'))
    identity = F.tile(F.expand_dims(identity, 0), (n_rot, 1, 1))

    axis_outer = F.matmul(axis, axis, transb=True)
    skew = batch_skew(axis, n_rot)

    return cos_term * identity + (1 - cos_term) * axis_outer + sin_term * skew
00121
00122
class SMPL(chainer.Chain):
    """SMPL body model forward pass as a Chainer chain.

    All model parameters are created zero-filled with fixed shapes; real
    values are expected to be loaded into the chain afterwards (the loading
    code is not part of this class).
    """

    def __init__(self):
        super(SMPL, self).__init__()
        # Parent index of every joint in the kinematic tree.  The root has
        # no parent and is marked with -1, which is what the former
        # 4294967295 wrapped to in int32; that out-of-range conversion is
        # deprecated and later rejected by newer NumPy versions.
        self.parents = np.array([
            -1, 0, 0, 0, 1,
            2, 3, 4, 5, 6, 7, 8, 9, 9, 9,
            12, 13, 14, 16, 17, 18, 19, 20, 21], 'i')

        with self.init_scope():
            # Mean template mesh.
            self.v_template = Parameter(0, (6890, 3))
            # [number of vertices, 3], used to reshape flattened meshes.
            self.size = [self.v_template.shape[0], 3]
            # Shape blend-shape basis (10 betas -> 6890 * 3 offsets).
            self.shapedirs = Parameter(0, (10, 20670))
            # Regressor from vertices to the 24 skeleton joints.
            self.J_regressor = Parameter(0, (6890, 24))
            # Pose blend-shape basis (23 * 9 features -> 6890 * 3 offsets).
            self.posedirs = Parameter(0, (207, 20670))
            # Linear blend skinning weights.
            self.weights = Parameter(0, (6890, 24))
            # Regressor from posed vertices to the 19 output joints.
            self.joint_regressor = Parameter(0, (6890, 19))

    def __call__(self, beta, theta, get_skin=False, with_a=False):
        """Pose and shape the template mesh.

        Args:
            beta: N x 10 shape coefficients.
            theta: N x 72 pose as axis-angle vectors (24 joints x 3).
            get_skin: accepted for interface compatibility; ignored.
            with_a: accepted for interface compatibility; ignored.

        Returns:
            verts: N x 6890 x 3 posed vertices.
            joints: N x 19 x 3 joints regressed from the posed vertices.
            Rs: N x 24 x 3 x 3 per-joint rotation matrices.
            A: N x 24 x 4 x 4 joint transforms as returned by
                ``batch_global_rigid_transformation``.

        Intermediate results are also stored on the instance as
        ``beta_shapedirs``, ``v_shaped``, ``J``, ``Rs``, ``pose_feature``
        and ``J_transformed``.
        """
        batch_size = beta.shape[0]
        num_joints = self.parents.shape[0]
        num_verts, num_dims = self.size

        # 1. Add the shape blend shapes to the template.
        self.beta_shapedirs = F.matmul(beta, self.shapedirs)
        v_shaped = F.reshape(
            self.beta_shapedirs, (-1, num_verts, num_dims)) + \
            F.repeat(self.v_template[None, ], batch_size, axis=0)
        self.v_shaped = v_shaped

        # 2. Regress the rest-pose joint locations, one axis at a time.
        J = F.stack(
            [F.matmul(v_shaped[:, :, axis], self.J_regressor)
             for axis in range(3)],
            axis=2)
        self.J = J

        # 3. Per-joint rotation matrices from the axis-angle pose.
        Rs = F.reshape(
            batch_rodrigues(F.reshape(theta, (-1, 3))),
            (-1, num_joints, 3, 3))
        self.Rs = Rs

        # Pose feature: deviation of every non-root rotation from identity.
        eye = self.xp.array(self.xp.eye(3), 'f')
        identity = Variable(
            self.xp.tile(eye, (batch_size, num_joints - 1, 1, 1)))
        pose_feature = F.reshape(
            Rs[:, 1:, :, :] - identity, (-1, 9 * (num_joints - 1)))
        self.pose_feature = pose_feature

        # 4. Add the pose blend shapes.
        v_posed = F.reshape(
            F.matmul(pose_feature, self.posedirs),
            (-1, num_verts, num_dims)) + v_shaped

        # 5. Global joint transforms along the kinematic tree.
        self.J_transformed, A = batch_global_rigid_transformation(
            Rs, J, self.parents)

        # 6. Linear blend skinning: blend the joint transforms per vertex
        # and apply them to the homogeneous vertex positions.
        W = F.reshape(
            F.tile(self.weights, (batch_size, 1)),
            (batch_size, -1, num_joints))
        T = F.reshape(
            F.matmul(W, F.reshape(A, (batch_size, num_joints, 16))),
            (batch_size, -1, 4, 4))
        v_posed_homo = F.concat(
            [v_posed,
             self.xp.ones((batch_size, v_posed.shape[1], 1), 'f')], 2)
        v_homo = F.matmul(T, F.expand_dims(v_posed_homo, -1))

        verts = v_homo[:, :, :3, 0]

        # 7. Regress the output joints from the posed vertices.
        joints = F.stack(
            [F.matmul(verts[:, :, axis], self.joint_regressor)
             for axis in range(3)],
            axis=2)

        return verts, joints, Rs, A