#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>

from argparse import ArgumentParser

import torch


def print_dictionary(filename):
    """
    Print the name and shape of every entry in a saved PyTorch state dict.
    [command-line module interface]

    The file is loaded onto the CPU, so checkpoints saved from GPU tensors
    can be inspected on machines without CUDA. Entries that have no
    ``shape`` attribute (e.g. an epoch counter stored next to the weights)
    are reported by their type name instead of raising ``AttributeError``.

    :param filename: path to the PyTorch model file; the loaded object must
        be dict-like (it is iterated through ``items()``).

    Example::

        python -m dlab.debug.torch_model path/to/some.pytorch.model.file

    """
    # NOTE: torch.load unpickles the file -- only inspect files you trust.
    status_dict = torch.load(filename, map_location='cpu')
    for name, value in status_dict.items():
        # Compare against None explicitly: a 0-dim tensor has an empty
        # (falsy) torch.Size that must still be printed as its shape.
        shape = getattr(value, 'shape', None)
        print(name, shape if shape is not None else type(value).__name__)


if __name__ == '__main__':
    # Command-line entry point: takes one positional argument, the path of
    # the model file, and prints that file's entries.
    cli = ArgumentParser(description='Show the pytorch model information.')
    cli.add_argument('filename', type=str, help='path to model file.')
    print_dictionary(cli.parse_args().filename)