| | #include "Encoder.hpp" |
| | #include "DecoderMain.hpp" |
| | #include "DecoderLoop.hpp" |
| |
|
| | #include <stdio.h> |
| | #include <ctime> |
| | #include <sys/time.h> |
| |
|
| | #include <ax_sys_api.h> |
| |
|
/**
 * Returns a timestamp in milliseconds suitable for measuring elapsed time.
 *
 * The value is read from the monotonic clock, so the difference between two
 * calls is not disturbed by wall-clock adjustments (NTP steps, manual date
 * changes). The absolute value has no calendar meaning; only differences
 * between two results should be used.
 *
 * @return Milliseconds (with sub-millisecond fraction) since an unspecified
 *         fixed starting point; 0.0 if the clock could not be read.
 */
static double get_current_time()
{
    // Zero-initialised so a failed clock read yields a deterministic 0.0.
    struct timespec ts = {0, 0};
    clock_gettime(CLOCK_MONOTONIC, &ts);

    // Seconds -> ms, nanoseconds -> ms.
    return ts.tv_sec * 1000.0 + ts.tv_nsec / 1.0e6;
}
| |
|
| | int main(int argc, char** argv) { |
| | int ret = AX_SYS_Init(); |
| | if (0 != ret) { |
| | fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret); |
| | return -1; |
| | } |
| |
|
| | AX_ENGINE_NPU_ATTR_T npu_attr; |
| | memset(&npu_attr, 0, sizeof(npu_attr)); |
| | npu_attr.eHardMode = static_cast<AX_ENGINE_NPU_MODE_T>(0); |
| | ret = AX_ENGINE_Init(&npu_attr); |
| | if (0 != ret) { |
| | fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); |
| | return -1; |
| | } |
| |
|
| | Encoder encoder; |
| | DecoderMain decoder_main; |
| | DecoderLoop decoder_loop; |
| |
|
| | double start, end; |
| | double whole_start, whole_end; |
| |
|
| | start = get_current_time(); |
| | if (0 != encoder.Init("../axmodel/encoder.axmodel")) { |
| | printf("Init encoder failed!\n"); |
| | return -1; |
| | } |
| | end = get_current_time(); |
| | printf("Load encoder take %.2fms\n", end - start); |
| |
|
| | start = get_current_time(); |
| | if (0 != decoder_main.Init("../axmodel/decoder_main.axmodel")) { |
| | printf("Init decoder_main failed!\n"); |
| | return -1; |
| | } |
| | end = get_current_time(); |
| | printf("Load decoder_main take %.2fms\n", end - start); |
| |
|
| | start = get_current_time(); |
| | if (0 != decoder_loop.Init("../axmodel/decoder_loop.axmodel")) { |
| | printf("Init decoder_loop failed!\n"); |
| | return -1; |
| | } |
| | end = get_current_time(); |
| | printf("Load decoder_loop take %.2fms\n", end - start); |
| |
|
| | std::vector<float> encoder_inputs(encoder.GetInputSize(0) / sizeof(float)); |
| | std::vector<float> encoder_input_lengths(encoder.GetInputSize(1) / sizeof(float)); |
| | encoder_input_lengths[0] = 100; |
| |
|
| | std::vector<float> n_layer_cross_k(encoder.GetOutputSize(0) / sizeof(float)); |
| | std::vector<float> n_layer_cross_v(encoder.GetOutputSize(1) / sizeof(float)); |
| | std::vector<float> cross_attn_mask(encoder.GetOutputSize(2) / sizeof(float)); |
| |
|
| | start = get_current_time(); |
| | whole_start = start; |
| | encoder.SetInput(encoder_inputs.data(), 0); |
| | encoder.SetInput(encoder_input_lengths.data(), 1); |
| | encoder.Run(); |
| | |
| | |
| | |
| | end = get_current_time(); |
| | printf("Run encoder take %.2fms\n", end - start); |
| |
|
| | std::vector<int> tokens(decoder_main.GetInputSize(0) / sizeof(int)); |
| |
|
| | std::vector<int> logits(decoder_main.GetOutputSize(0) / sizeof(int)); |
| | std::vector<float> n_layer_self_k_cache(decoder_main.GetOutputSize(1) / sizeof(float)); |
| | std::vector<float> n_layer_self_v_cache(decoder_main.GetOutputSize(2) / sizeof(float)); |
| |
|
| | start = get_current_time(); |
| | decoder_main.SetInput(tokens.data(), 0); |
| | |
| | |
| | |
| | decoder_main.SetInput(n_layer_cross_k.data(), 1); |
| | decoder_main.SetInput(n_layer_cross_v.data(), 2); |
| | decoder_main.SetInput(cross_attn_mask.data(), 3); |
| | decoder_main.Run(); |
| | decoder_main.GetOutput(logits.data(), 0); |
| | |
| | |
| | end = get_current_time(); |
| | printf("Run decoder_main take %.2fms\n", end - start); |
| |
|
| | std::vector<float> pe(decoder_loop.GetOutputSize(5) / sizeof(float)); |
| | std::vector<float> self_attn_mask(decoder_loop.GetOutputSize(6) / sizeof(float)); |
| |
|
| | decoder_loop.SetInput(n_layer_cross_k.data(), 3); |
| | decoder_loop.SetInput(n_layer_cross_v.data(), 4); |
| | for (int i = 0; i < 14; i++) { |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | start = get_current_time(); |
| | decoder_loop.SetInput(tokens.data(), 0); |
| | decoder_loop.SetInput(decoder_loop.GetOutputPtr(1), 1); |
| | decoder_loop.SetInput(decoder_loop.GetOutputPtr(2), 2); |
| | |
| | |
| | |
| | decoder_loop.Run(); |
| | decoder_loop.GetOutput(logits.data(), 0); |
| | |
| | |
| | end = get_current_time(); |
| | printf("Run decoder_loop take %.2fms\n", end - start); |
| | } |
| |
|
| | whole_end = get_current_time(); |
| | printf("Whole duration %.2fms\n", whole_end - whole_start); |
| | printf("RTF: %.4f\n", (whole_end - whole_start) / 4000.0); |
| |
|
| | return 0; |
| | } |