import os
import json
import torch
import numpy as np
import smplx


def batch_orth_proj_idrot(X, camera):
    """
    Batched orthographic projection with an identity rotation.

    Same as applying ``orth_proj_idrot`` to each of the N batch elements:
    the x/y coordinates of every point are shifted by that sample's camera
    translation and then multiplied by its camera scale; z is dropped.

    X: N x num_points x 3 (only the first two coordinates are used)
    camera: N x 3, each row laid out as [scale, tx, ty]
    returns: N x num_points x 2

    TODO: the shapes of X and camera are not validated.
    """
    # A singleton points axis lets the per-sample camera terms broadcast
    # over every point belonging to that sample.
    scale = camera[:, None, :1]    # N x 1 x 1
    offset = camera[:, None, 1:]   # N x 1 x 2
    xy = X[:, :, :2]               # N x num_points x 2
    return scale * (xy + offset)


if __name__ == '__main__':
    # Sample code: generate image-based 2D joints and human-centered 3D joints
    # and vertices from 85-D SMPL parameter vectors laid out as
    #   [0:3] camera [scale, tx, ty] | [3:6] global orientation |
    #   [6:75] body pose (69 values) | [75:85] shape betas (10 values)
    #
    # Per input file (batch size 1):
    #   im_j2d:   1 x 24 x 2
    #   hc_j3d:   1 x 24 x 3
    #   hc_verts: 1 x 6890 x 3
    # The results are computed but not saved or returned.

    # With a directory model_path, smplx.create looks inside
    # <model_path>/<model_type>/, so with the arguments below the SMPL neutral
    # model must be placed at smplx/models/smpl/SMPL_NEUTRAL.pkl.
    # NOTE(review): use_face_contour, num_expression_coeffs, ext and age look
    # like options for other model types; confirm they matter for 'smpl'.
    smpl = smplx.create('smplx/models', model_type='smpl',
                        gender='neutral', use_face_contour=False,
                        num_betas=10,
                        num_expression_coeffs=10,
                        ext='npz',
                        age='adult')

    path = '00001/000_00'
    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for fname in sorted(os.listdir(path)):
            fpath = os.path.join(path, fname)
            with open(fpath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Flatten so both a flat list and a nested [[...]] payload work,
            # and fail with a clear message instead of an opaque reshape error.
            theta_np = np.asarray(data["theta"], dtype=np.float32).reshape(-1)
            if theta_np.size != 85:
                raise ValueError(
                    f'{fpath}: expected 85 theta values, got {theta_np.size}')
            thetas = torch.from_numpy(theta_np)

            cam = thetas[0:3].reshape(1, 3)            # [scale, tx, ty]
            global_orient = thetas[3:6].reshape(1, 3)
            body_pose = thetas[6:75].reshape(1, 69)
            betas = thetas[75:85].reshape(1, 10)

            # Image-based 2D joints: pose the body with its global orientation,
            # keep the first 24 joints and project them with the camera.
            output = smpl(betas=betas, body_pose=body_pose,
                          global_orient=global_orient, return_verts=True)
            im_j2d = batch_orth_proj_idrot(output.joints[:, 0:24, :], cam)

            # Human-centered 3D joints and vertices: same pose and shape, but
            # with the global orientation zeroed out.
            output = smpl(betas=betas, body_pose=body_pose,
                          global_orient=torch.zeros_like(global_orient),
                          return_verts=True)
            hc_verts = output.vertices              # 1 x 6890 x 3
            hc_j3d = output.joints[:, 0:24, :]      # 1 x 24 x 3