From a8043b89f4f3f93104245e02be1611e5868e8342 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 23 May 2025 08:09:49 -0400 Subject: [PATCH] [REFACTOR][FFI] Phase out legacy C API This PR phases out the legacy c api in favor of the new FFI C API. Also removes the redirection sccafolding for registry.h - include => include - include => include - TVM_REGISTER_GLOBAL => TVM_FFI_REGISTER_GLOBAL The cleanup will greatly simplify the overall FFI surface of the project and allows us to move towards an unified clean API based on tvm ffi. --- 3rdparty/cutlass_fpA_intB_gemm | 2 +- .../app/src/main/jni/tvm_runtime.h | 4 +- apps/cpp_rpc/rpc_env.cc | 4 +- apps/cpp_rpc/rpc_env.h | 2 +- apps/cpp_rpc/rpc_server.cc | 4 +- apps/cpp_rpc/rpc_server.h | 2 +- apps/hexagon_launcher/launcher_core.cc | 2 +- apps/ios_rpc/tvmrpc/RPCServer.mm | 2 +- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 8 +- docs/arch/device_target_interactions.rst | 10 +- docs/arch/pass_infra.rst | 2 +- docs/arch/runtime.rst | 8 +- ffi/include/tvm/ffi/c_api.h | 7 + ffi/include/tvm/ffi/function.h | 2 +- include/tvm/ir/op.h | 2 +- include/tvm/ir/source_map.h | 2 +- include/tvm/node/node.h | 2 +- include/tvm/node/reflection.h | 2 +- include/tvm/node/serialization.h | 2 +- include/tvm/relax/exec_builder.h | 2 +- include/tvm/relax/type.h | 2 +- include/tvm/runtime/base.h | 59 ++ include/tvm/runtime/builtin_fp16.h | 2 +- include/tvm/runtime/c_backend_api.h | 34 +- include/tvm/runtime/c_runtime_api.h | 732 ---------------- include/tvm/runtime/data_type.h | 5 +- include/tvm/runtime/device_api.h | 42 +- include/tvm/runtime/disco/cuda_ipc_memory.h | 2 +- include/tvm/runtime/logging.h | 2 +- include/tvm/runtime/memory/memory_manager.h | 2 +- include/tvm/runtime/module.h | 2 +- include/tvm/runtime/ndarray.h | 48 +- include/tvm/runtime/nvtx.h | 2 +- include/tvm/runtime/object.h | 2 +- include/tvm/runtime/packed_func.h | 282 +----- include/tvm/runtime/profiling.h | 6 +- include/tvm/runtime/registry.h | 102 --- include/tvm/runtime/relax_vm/executable.h | 2 +- .../runtime/relax_vm/ndarray_cache_support.h | 2 +- include/tvm/runtime/serializer.h | 2 +- include/tvm/support/parallel_for.h | 2 +- include/tvm/tir/builtin.h | 20 +- include/tvm/tir/expr.h | 2 +- include/tvm/tir/transform.h | 4 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +- python/tvm/relax/frontend/nn/op.py | 2 +- python/tvm/runtime/_ffi_api.py | 4 +- python/tvm/runtime/_ffi_node_api.py | 6 +- src/arith/analyzer.cc | 4 +- src/arith/bound_deducer.cc | 4 +- src/arith/const_int_bound.cc | 4 +- src/arith/detect_common_subexpr.cc | 2 +- src/arith/detect_linear_equation.cc | 6 +- src/arith/domain_touched.cc | 6 +- src/arith/int_constraints.cc | 13 +- src/arith/int_set.cc | 30 +- src/arith/iter_affine_map.cc | 18 +- src/arith/modular_set.cc | 4 +- src/arith/narrow_predicate_expression.cc | 5 +- src/arith/presburger_set.cc | 4 +- src/arith/solve_linear_equation.cc | 4 +- src/arith/solve_linear_inequality.cc | 8 +- src/contrib/msc/core/ir/graph.cc | 101 +-- src/contrib/msc/core/ir/graph_builder.cc | 4 +- src/contrib/msc/core/ir/plugin.cc | 8 +- .../msc/core/transform/bind_named_params.cc | 2 +- src/contrib/msc/core/transform/bind_shape.cc | 2 +- src/contrib/msc/core/transform/fuse_tuple.cc | 2 +- .../msc/core/transform/inline_params.cc | 2 +- .../msc/core/transform/set_byoc_attrs.cc | 2 +- .../msc/core/transform/set_expr_layout.cc | 2 +- .../msc/core/transform/set_expr_name.cc | 2 +- src/contrib/msc/core/utils.cc | 12 +- .../msc/framework/tensorflow/codegen.cc | 2 +- src/contrib/msc/framework/tensorrt/codegen.cc | 6 +- .../framework/tensorrt/transform_tensorrt.cc | 2 +- src/contrib/msc/framework/torch/codegen.cc | 2 +- src/contrib/msc/framework/tvm/codegen.cc | 2 +- src/contrib/msc/plugin/tensorrt_codegen.cc | 2 +- src/contrib/msc/plugin/torch_codegen.cc | 2 +- src/contrib/msc/plugin/tvm_codegen.cc | 8 +- src/ir/analysis.cc | 2 +- src/ir/apply_pass_to_function.cc | 4 +- src/ir/attrs.cc | 6 +- src/ir/diagnostic.cc | 22 +- src/ir/env_func.cc | 8 +- src/ir/expr.cc | 14 +- src/ir/function.cc | 14 +- src/ir/global_info.cc | 9 +- src/ir/global_var_supply.cc | 17 +- src/ir/instrument.cc | 8 +- src/ir/module.cc | 46 +- src/ir/name_supply.cc | 11 +- src/ir/op.cc | 26 +- src/ir/replace_global_vars.cc | 4 +- src/ir/source_map.cc | 23 +- src/ir/transform.cc | 46 +- src/ir/type.cc | 15 +- src/meta_schedule/arg_info.cc | 11 +- src/meta_schedule/builder/builder.cc | 8 +- src/meta_schedule/cost_model/cost_model.cc | 11 +- src/meta_schedule/database/database.cc | 45 +- src/meta_schedule/database/json_database.cc | 3 +- src/meta_schedule/database/memory_database.cc | 2 +- .../database/ordered_union_database.cc | 2 +- .../database/schedule_fn_database.cc | 2 +- src/meta_schedule/database/union_database.cc | 3 +- src/meta_schedule/extracted_task.cc | 2 +- .../feature_extractor/feature_extractor.cc | 4 +- .../feature_extractor/per_store_feature.cc | 2 +- .../measure_callback/add_to_database.cc | 2 +- .../measure_callback/measure_callback.cc | 6 +- .../measure_callback/remove_build_artifact.cc | 2 +- .../measure_callback/update_cost_model.cc | 2 +- .../mutator/mutate_compute_location.cc | 2 +- src/meta_schedule/mutator/mutate_parallel.cc | 3 +- .../mutator/mutate_thread_binding.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 3 +- src/meta_schedule/mutator/mutate_unroll.cc | 2 +- src/meta_schedule/mutator/mutator.cc | 17 +- .../disallow_async_strided_mem_copy.cc | 2 +- .../postproc/disallow_dynamic_loop.cc | 2 +- src/meta_schedule/postproc/postproc.cc | 16 +- .../postproc/rewrite_cooperative_fetch.cc | 2 +- src/meta_schedule/postproc/rewrite_layout.cc | 3 +- .../rewrite_parallel_vectorize_unroll.cc | 2 +- .../postproc/rewrite_reduction_block.cc | 2 +- .../postproc/rewrite_tensorize.cc | 2 +- .../postproc/rewrite_unbound_block.cc | 2 +- src/meta_schedule/postproc/verify_gpu_code.cc | 3 +- .../postproc/verify_vtcm_limit.cc | 2 +- src/meta_schedule/profiler.cc | 14 +- src/meta_schedule/runner/runner.cc | 15 +- src/meta_schedule/schedule/cpu/winograd.cc | 8 +- src/meta_schedule/schedule/cuda/winograd.cc | 8 +- .../schedule_rule/add_rfactor.cc | 2 +- .../schedule_rule/apply_custom_rule.cc | 2 +- src/meta_schedule/schedule_rule/auto_bind.cc | 3 +- .../schedule_rule/auto_inline.cc | 4 +- .../schedule_rule/cross_thread_reduction.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 2 +- .../multi_level_tiling_tensor_core.cc | 2 +- .../multi_level_tiling_wide_vector.cc | 2 +- .../multi_level_tiling_with_intrin.cc | 2 +- .../parallel_vectorize_unroll.cc | 2 +- .../schedule_rule/random_compute_location.cc | 2 +- .../schedule_rule/schedule_rule.cc | 20 +- .../search_strategy/evolutionary_search.cc | 6 +- .../search_strategy/replay_func.cc | 2 +- .../search_strategy/replay_trace.cc | 2 +- .../search_strategy/search_strategy.cc | 16 +- .../space_generator/post_order_apply.cc | 2 +- .../space_generator/schedule_fn.cc | 2 +- .../space_generator/space_generator.cc | 8 +- .../space_generator/space_generator_union.cc | 2 +- .../task_scheduler/gradient_based.cc | 2 +- .../task_scheduler/round_robin.cc | 2 +- .../task_scheduler/task_scheduler.cc | 15 +- src/meta_schedule/trace_apply.cc | 2 +- src/meta_schedule/tune_context.cc | 8 +- src/node/container_printing.cc | 2 +- src/node/object_path.cc | 24 +- src/node/reflection.cc | 8 +- src/node/repr_printer.cc | 6 +- src/node/script_printer.cc | 6 +- src/node/serialization.cc | 7 +- src/node/structural_equal.cc | 10 +- src/node/structural_hash.cc | 4 +- src/relax/analysis/analysis.cc | 10 +- .../analysis/computable_at_compile_time.cc | 2 +- src/relax/analysis/detect_recursion.cc | 2 +- src/relax/analysis/layout_transformation.cc | 2 +- src/relax/analysis/struct_info_analysis.cc | 22 +- src/relax/analysis/tir_op_pattern_kind.cc | 2 +- src/relax/analysis/udchain.cc | 2 +- src/relax/analysis/var2value.cc | 4 +- src/relax/analysis/well_formed.cc | 2 +- src/relax/backend/contrib/clml/codegen.cc | 7 +- src/relax/backend/contrib/cublas/codegen.cc | 2 +- src/relax/backend/contrib/cudnn/codegen.cc | 2 +- src/relax/backend/contrib/cutlass/codegen.cc | 4 +- src/relax/backend/contrib/dnnl/codegen.cc | 2 +- src/relax/backend/contrib/hipblas/codegen.cc | 2 +- src/relax/backend/contrib/nnapi/codegen.cc | 4 +- src/relax/backend/contrib/tensorrt/codegen.cc | 7 +- src/relax/backend/contrib/utils.cc | 2 +- src/relax/backend/pattern_registry.cc | 9 +- src/relax/backend/task_extraction.cc | 2 +- src/relax/backend/vm/codegen_vm.cc | 4 +- src/relax/backend/vm/codegen_vm_tir.cc | 2 +- src/relax/backend/vm/exec_builder.cc | 44 +- src/relax/backend/vm/lower_runtime_builtin.cc | 2 +- src/relax/backend/vm/vm_shape_lower.cc | 2 +- src/relax/distributed/global_info.cc | 2 +- src/relax/distributed/struct_info.cc | 10 +- .../transform/legalize_redistribute.cc | 2 +- .../distributed/transform/lower_distir.cc | 2 +- .../lower_global_view_to_local_view.cc | 2 +- .../transform/propagate_sharding.cc | 2 +- src/relax/ir/binding_rewrite.cc | 16 +- src/relax/ir/block_builder.cc | 40 +- src/relax/ir/dataflow_block_rewriter.cc | 4 +- src/relax/ir/dataflow_expr_rewriter.cc | 21 +- src/relax/ir/dataflow_pattern.cc | 74 +- src/relax/ir/emit_te.cc | 2 +- src/relax/ir/expr.cc | 62 +- src/relax/ir/expr_functor.cc | 2 +- src/relax/ir/py_expr_functor.cc | 50 +- src/relax/ir/struct_info.cc | 31 +- src/relax/ir/transform.cc | 6 +- src/relax/ir/type.cc | 12 +- src/relax/op/ccl/ccl.cc | 9 +- src/relax/op/distributed/distributed.cc | 8 +- src/relax/op/image/resize.cc | 2 +- src/relax/op/memory/view.cc | 7 +- src/relax/op/nn/attention.cc | 4 +- src/relax/op/nn/convolution.cc | 10 +- src/relax/op/nn/nn.cc | 28 +- src/relax/op/nn/pooling.cc | 18 +- src/relax/op/op.cc | 56 +- src/relax/op/op_common.h | 2 +- src/relax/op/tensor/binary.h | 2 +- src/relax/op/tensor/create.cc | 22 +- src/relax/op/tensor/datatype.cc | 4 +- src/relax/op/tensor/grad.cc | 14 +- src/relax/op/tensor/index.cc | 6 +- src/relax/op/tensor/linear_algebra.cc | 6 +- src/relax/op/tensor/manipulate.cc | 46 +- src/relax/op/tensor/qdq.cc | 4 +- src/relax/op/tensor/sampling.cc | 3 +- src/relax/op/tensor/search.cc | 4 +- src/relax/op/tensor/set.cc | 4 +- src/relax/op/tensor/sorting.cc | 6 +- src/relax/op/tensor/statistical.cc | 4 +- src/relax/op/tensor/statistical.h | 2 +- src/relax/op/tensor/ternary.cc | 2 +- src/relax/op/tensor/unary.cc | 2 +- src/relax/testing/transform.cc | 2 +- src/relax/training/utils.cc | 2 +- src/relax/transform/adjust_matmul_order.cc | 2 +- src/relax/transform/allocate_workspace.cc | 2 +- src/relax/transform/alter_op_impl.cc | 2 +- .../transform/annotate_tir_op_pattern.cc | 3 +- .../attach_attr_layout_free_buffers.cc | 2 +- src/relax/transform/attach_global_symbol.cc | 2 +- src/relax/transform/bind_params.cc | 4 +- src/relax/transform/bind_symbolic_vars.cc | 4 +- src/relax/transform/bundle_model_params.cc | 2 +- src/relax/transform/call_tir_rewrite.cc | 2 +- src/relax/transform/canonicalize_bindings.cc | 3 +- .../transform/combine_parallel_matmul.cc | 3 +- src/relax/transform/compute_prim_value.cc | 2 +- src/relax/transform/convert_dataflow.cc | 2 +- src/relax/transform/convert_layout.cc | 2 +- src/relax/transform/dataflow_inplace.cc | 10 +- src/relax/transform/dead_code_elimination.cc | 2 +- src/relax/transform/decompose_ops.cc | 4 +- .../transform/eliminate_common_subexpr.cc | 2 +- src/relax/transform/expand_matmul_of_sum.cc | 2 +- src/relax/transform/expand_tuple_arguments.cc | 3 +- src/relax/transform/few_shot_tuning.cc | 2 +- src/relax/transform/fold_constant.cc | 2 +- src/relax/transform/fuse_ops.cc | 6 +- src/relax/transform/fuse_tir.cc | 2 +- src/relax/transform/gradient.cc | 2 +- src/relax/transform/inline_functions.cc | 4 +- src/relax/transform/kill_after_last_use.cc | 2 +- src/relax/transform/lambda_lift.cc | 2 +- src/relax/transform/lazy_transform_params.cc | 4 +- src/relax/transform/legalize_ops.cc | 2 +- src/relax/transform/lift_transform_params.cc | 2 +- src/relax/transform/lower_alloc_tensor.cc | 2 +- .../transform/merge_composite_functions.cc | 2 +- src/relax/transform/meta_schedule.cc | 7 +- src/relax/transform/normalize.cc | 4 +- src/relax/transform/realize_vdevice.cc | 2 +- src/relax/transform/remove_purity_checking.cc | 3 +- src/relax/transform/remove_unused_outputs.cc | 2 +- .../transform/remove_unused_parameters.cc | 2 +- .../reorder_permute_dims_after_concat.cc | 2 +- .../transform/reorder_take_after_matmul.cc | 2 +- src/relax/transform/rewrite_cuda_graph.cc | 2 +- .../transform/rewrite_dataflow_reshape.cc | 2 +- src/relax/transform/run_codegen.cc | 2 +- .../transform/split_call_tir_by_pattern.cc | 3 +- .../transform/split_layout_rewrite_preproc.cc | 2 +- .../transform/static_plan_block_memory.cc | 3 +- src/relax/transform/to_mixed_precision.cc | 2 +- src/relax/transform/to_non_dataflow.cc | 2 +- src/relax/transform/topological_sort.cc | 2 +- src/relax/transform/tuning_api/database.cc | 26 +- src/relax/transform/tuning_api/primitives.cc | 39 +- .../transform/update_param_struct_info.cc | 3 +- src/relax/transform/update_vdevice.cc | 2 +- src/relax/utils.cc | 2 +- src/runtime/builtin_fp16.cc | 8 +- src/runtime/c_runtime_api.cc | 807 ------------------ src/runtime/const_loader_module.cc | 6 +- src/runtime/container.cc | 101 --- src/runtime/contrib/amx/amx_config.cc | 6 +- .../contrib/arm_compute_lib/acl_allocator.h | 2 +- .../contrib/arm_compute_lib/acl_runtime.cc | 6 +- src/runtime/contrib/bnns/bnns_json_runtime.cc | 6 +- src/runtime/contrib/cblas/cblas.cc | 8 +- src/runtime/contrib/cblas/dnnl_blas.cc | 4 +- src/runtime/contrib/cblas/gemm_common.h | 2 +- src/runtime/contrib/cblas/mkl.cc | 10 +- src/runtime/contrib/clml/clml_runtime.cc | 4 +- src/runtime/contrib/clml/clml_runtime.h | 2 +- src/runtime/contrib/coreml/coreml_runtime.mm | 7 +- src/runtime/contrib/cublas/cublas.cc | 8 +- .../contrib/cublas/cublas_json_runtime.cc | 6 +- src/runtime/contrib/cublas/cublas_utils.cc | 2 +- src/runtime/contrib/cudnn/conv_backward.cc | 10 +- src/runtime/contrib/cudnn/conv_forward.cc | 10 +- .../contrib/cudnn/cudnn_frontend/attention.cc | 2 +- .../contrib/cudnn/cudnn_frontend/attention.h | 2 +- .../contrib/cudnn/cudnn_json_runtime.cc | 6 +- src/runtime/contrib/cudnn/cudnn_utils.cc | 4 +- src/runtime/contrib/cudnn/softmax.cc | 6 +- src/runtime/contrib/curand/curand.cc | 6 +- .../contrib/curand/helper_cuda_kernels.h | 2 +- .../contrib/cutlass/fp16_group_gemm.cu | 4 +- .../cutlass/fp8_blockwise_scaled_gemm.cu | 6 +- src/runtime/contrib/cutlass/fp8_gemm.cu | 8 +- src/runtime/contrib/cutlass/fp8_group_gemm.cu | 8 +- .../contrib/cutlass/weight_preprocess.cc | 4 +- src/runtime/contrib/dnnl/dnnl.cc | 2 +- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 6 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 4 +- .../contrib/edgetpu/edgetpu_runtime.cc | 4 +- src/runtime/contrib/hipblas/hipblas.cc | 6 +- .../contrib/hipblas/hipblas_json_runtime.cc | 7 +- src/runtime/contrib/hipblas/hipblas_utils.cc | 2 +- src/runtime/contrib/miopen/conv_forward.cc | 6 +- src/runtime/contrib/miopen/miopen_utils.cc | 2 +- src/runtime/contrib/miopen/softmax.cc | 6 +- src/runtime/contrib/mps/conv.mm | 6 +- src/runtime/contrib/mps/gemm.mm | 2 +- src/runtime/contrib/mps/mps_utils.h | 2 +- src/runtime/contrib/mrvl/mrvl_hw_runtime.cc | 6 +- src/runtime/contrib/mrvl/mrvl_runtime.cc | 6 +- .../contrib/mrvl/mrvl_sw_runtime_lib.cc | 2 +- src/runtime/contrib/msc/tensorrt_runtime.cc | 7 +- src/runtime/contrib/mscclpp/allreduce.cu | 2 +- src/runtime/contrib/nnapi/nnapi_runtime.cc | 6 +- src/runtime/contrib/nvshmem/init.cc | 8 +- src/runtime/contrib/nvshmem/kv_transfer.cu | 6 +- .../contrib/nvshmem/memory_allocator.cc | 6 +- src/runtime/contrib/papi/papi.cc | 2 +- src/runtime/contrib/random/random.cc | 12 +- src/runtime/contrib/rocblas/rocblas.cc | 6 +- src/runtime/contrib/sort/sort.cc | 21 +- .../contrib/tensorrt/tensorrt_runtime.cc | 6 +- src/runtime/contrib/tflite/tflite_runtime.cc | 6 +- src/runtime/contrib/thrust/thrust.cu | 8 +- src/runtime/contrib/vllm/attention_kernels.cu | 8 +- src/runtime/contrib/vllm/cache_alloc.cc | 4 +- src/runtime/contrib/vllm/cache_kernels.cu | 8 +- src/runtime/cpu_device_api.cc | 4 +- src/runtime/cuda/cuda_device_api.cc | 21 +- src/runtime/cuda/cuda_module.cc | 8 +- src/runtime/cuda/l2_cache_flush.cc | 13 +- src/runtime/debug_compile.cc | 6 +- src/runtime/device_api.cc | 271 ++++++ src/runtime/disco/bcast_session.cc | 2 +- src/runtime/disco/builtin.cc | 48 +- src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc | 9 +- .../disco/cuda_ipc/custom_allreduce.cc | 4 +- src/runtime/disco/disco_worker.cc | 2 +- .../disco/distributed/socket_session.cc | 8 +- src/runtime/disco/loader.cc | 25 +- src/runtime/disco/nccl/nccl.cc | 35 +- src/runtime/disco/nccl/nccl_context.h | 4 +- src/runtime/disco/process_session.cc | 8 +- src/runtime/disco/protocol.h | 4 +- src/runtime/disco/session.cc | 24 +- src/runtime/disco/threaded_session.cc | 2 +- src/runtime/dso_library.cc | 4 +- src/runtime/file_utils.cc | 17 +- src/runtime/hexagon/hexagon_buffer.h | 2 +- src/runtime/hexagon/hexagon_common.cc | 6 +- src/runtime/hexagon/hexagon_device_api.cc | 31 +- src/runtime/hexagon/hexagon_module.cc | 2 +- src/runtime/hexagon/hexagon_thread_manager.h | 2 +- src/runtime/hexagon/hexagon_vtcm_pool.h | 2 +- src/runtime/hexagon/ops/conv2d.h | 2 +- src/runtime/hexagon/ops/conv2d_fp16_hvx.cc | 30 +- src/runtime/hexagon/ops/conv2d_quant_hvx.cc | 49 +- src/runtime/hexagon/rpc/android/session.cc | 4 +- src/runtime/hexagon/rpc/hexagon/rpc_server.cc | 8 +- .../hexagon/rpc/simulator/rpc_server.cc | 8 +- src/runtime/hexagon/rpc/simulator/session.cc | 4 +- src/runtime/library_module.cc | 2 +- src/runtime/library_module.h | 2 +- src/runtime/logging.cc | 2 +- src/runtime/memory/memory_manager.cc | 4 +- src/runtime/meta_data.h | 2 - src/runtime/metal/metal_common.h | 2 +- src/runtime/metal/metal_device_api.mm | 8 +- src/runtime/metal/metal_module.mm | 6 +- src/runtime/minrpc/minrpc_server.h | 2 +- src/runtime/module.cc | 28 +- src/runtime/ndarray.cc | 101 +-- src/runtime/object.cc | 92 -- src/runtime/object_internal.h | 96 --- src/runtime/opencl/opencl_common.h | 2 +- src/runtime/opencl/opencl_device_api.cc | 19 +- src/runtime/opencl/opencl_module.cc | 10 +- src/runtime/opencl/opencl_module_spirv.cc | 2 +- src/runtime/pack_args.h | 9 +- src/runtime/packed_func.cc | 2 +- src/runtime/profiling.cc | 32 +- src/runtime/regex.cc | 11 +- src/runtime/registry.cc | 266 ------ src/runtime/relax_vm/builtin.cc | 129 +-- .../relax_vm/cuda/cuda_graph_builtin.cc | 6 +- src/runtime/relax_vm/executable.cc | 8 +- src/runtime/relax_vm/hexagon/builtin.cc | 6 +- src/runtime/relax_vm/kv_state.cc | 52 +- src/runtime/relax_vm/kv_state.h | 2 +- src/runtime/relax_vm/lm_support.cc | 32 +- src/runtime/relax_vm/ndarray_cache_support.cc | 28 +- src/runtime/relax_vm/paged_kv_cache.cc | 4 +- src/runtime/relax_vm/rnn_state.cc | 2 +- src/runtime/relax_vm/vm.cc | 7 - src/runtime/rocm/rocm_device_api.cc | 17 +- src/runtime/rocm/rocm_module.cc | 10 +- src/runtime/rpc/rpc_device_api.cc | 4 +- src/runtime/rpc/rpc_endpoint.cc | 5 +- src/runtime/rpc/rpc_event_impl.cc | 4 +- src/runtime/rpc/rpc_local_session.cc | 4 +- src/runtime/rpc/rpc_module.cc | 32 +- src/runtime/rpc/rpc_pipe_impl.cc | 17 +- src/runtime/rpc/rpc_server_env.cc | 8 +- src/runtime/rpc/rpc_socket_impl.cc | 8 +- src/runtime/runtime_base.h | 67 -- src/runtime/spirv/spirv_shader.h | 2 +- src/runtime/static_library.cc | 6 +- src/runtime/system_library.cc | 17 +- src/runtime/thread_pool.cc | 8 +- src/runtime/threading_backend.cc | 5 +- src/runtime/vulkan/vulkan_common.h | 4 +- src/runtime/vulkan/vulkan_device_api.cc | 11 +- src/runtime/vulkan/vulkan_module.cc | 6 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 2 +- src/script/ir_builder/base.cc | 26 +- src/script/ir_builder/ir/frame.cc | 2 +- src/script/ir_builder/ir/ir.cc | 18 +- src/script/ir_builder/relax/distributed.cc | 2 +- src/script/ir_builder/relax/ir.cc | 33 +- src/script/ir_builder/tir/ir.cc | 180 ++-- src/script/printer/doc.cc | 76 +- .../printer/doc_printer/python_doc_printer.cc | 4 +- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/relax/type.cc | 2 +- src/support/ffi_testing.cc | 101 ++- src/support/libinfo.cc | 4 +- src/support/socket.h | 3 +- src/target/build_common.h | 2 +- src/target/codegen.cc | 14 +- src/target/datatype/myfloat/myfloat.cc | 2 +- src/target/datatype/posit/posit-wrapper.cc | 2 +- src/target/datatype/registry.cc | 10 +- src/target/datatype/registry.h | 4 +- src/target/intrin_rule.h | 2 +- src/target/llvm/codegen_aarch64.cc | 4 +- src/target/llvm/codegen_amdgpu.cc | 8 +- src/target/llvm/codegen_arm.cc | 4 +- src/target/llvm/codegen_cpu.cc | 8 +- src/target/llvm/codegen_hexagon.cc | 4 +- src/target/llvm/codegen_llvm.cc | 13 +- src/target/llvm/codegen_nvptx.cc | 4 +- src/target/llvm/codegen_x86_64.cc | 4 +- src/target/llvm/intrin_rule_llvm.h | 2 +- src/target/llvm/intrin_rule_nvptx.cc | 2 +- src/target/llvm/intrin_rule_rocm.cc | 2 +- src/target/llvm/llvm_module.cc | 59 +- src/target/opt/build_cuda_on.cc | 2 +- src/target/source/codegen_c.cc | 17 +- src/target/source/codegen_c_host.cc | 16 +- src/target/source/codegen_cuda.cc | 2 +- src/target/source/codegen_metal.cc | 2 +- src/target/source/codegen_opencl.cc | 4 +- src/target/source/codegen_webgpu.cc | 2 +- src/target/source/source_module.cc | 8 +- src/target/spirv/build_vulkan.cc | 2 +- src/target/spirv/intrin_rule_spirv.cc | 2 +- src/target/tag.cc | 6 +- src/target/target.cc | 20 +- src/target/target_info.cc | 2 +- src/target/target_kind.cc | 11 +- src/target/virtual_device.cc | 2 +- src/te/operation/compute_op.cc | 4 +- src/te/operation/create_primfunc.cc | 21 +- src/te/operation/extern_op.cc | 4 +- src/te/operation/graph.cc | 6 +- src/te/operation/placeholder_op.cc | 4 +- src/te/operation/scan_op.cc | 4 +- src/te/tensor.cc | 14 +- .../analysis/block_access_region_detector.cc | 5 +- .../analysis/buffer_access_lca_detector.cc | 3 +- .../analysis/calculate_allocated_memory.cc | 6 +- src/tir/analysis/control_flow_graph.cc | 2 +- src/tir/analysis/deep_equal.cc | 4 +- src/tir/analysis/estimate_flops.cc | 23 +- src/tir/analysis/identify_memcpy.cc | 2 +- src/tir/analysis/is_pure_function.cc | 2 +- src/tir/analysis/oob_checker.cc | 2 +- src/tir/analysis/stmt_finding.cc | 2 +- src/tir/analysis/var_use_def_analysis.cc | 2 +- src/tir/analysis/verify_gpu_code.cc | 6 +- src/tir/analysis/verify_memory.cc | 6 +- src/tir/analysis/verify_ssa.cc | 6 +- src/tir/analysis/verify_well_formed.cc | 4 +- src/tir/ir/block_dependence_info.cc | 6 +- src/tir/ir/block_scope.cc | 19 +- src/tir/ir/buffer.cc | 17 +- src/tir/ir/data_layout.cc | 34 +- src/tir/ir/expr.cc | 83 +- src/tir/ir/function.cc | 10 +- src/tir/ir/index_map.cc | 19 +- src/tir/ir/script/script_complete.cc | 2 +- src/tir/ir/script/script_complete.h | 2 +- src/tir/ir/specialize.cc | 4 +- src/tir/ir/stmt.cc | 47 +- src/tir/ir/stmt_functor.cc | 10 +- src/tir/ir/transform.cc | 4 +- src/tir/op/builtin.cc | 2 +- src/tir/op/op.cc | 72 +- src/tir/schedule/analysis/analysis.cc | 27 +- src/tir/schedule/analysis/layout.cc | 2 +- src/tir/schedule/instruction.cc | 4 +- .../schedule/primitive/decompose_padding.cc | 2 +- src/tir/schedule/primitive/reduction.cc | 2 +- src/tir/schedule/schedule.cc | 136 +-- src/tir/schedule/state.cc | 10 +- src/tir/schedule/trace.cc | 20 +- src/tir/schedule/transform.cc | 4 +- src/tir/transforms/annotate_device_regions.cc | 5 +- src/tir/transforms/bind_params.cc | 2 +- src/tir/transforms/bound_checker.cc | 4 +- src/tir/transforms/combine_context_call.cc | 4 +- src/tir/transforms/common_subexpr_elim.cc | 2 +- src/tir/transforms/compact_buffer_region.cc | 2 +- .../transforms/convert_blocks_to_opaque.cc | 3 +- .../transforms/convert_for_loops_serial.cc | 2 +- src/tir/transforms/decorate_device_scope.cc | 4 +- src/tir/transforms/default_gpu_schedule.cc | 2 +- src/tir/transforms/extract_constants.cc | 4 +- src/tir/transforms/flatten_buffer.cc | 2 +- .../transforms/force_narrow_index_to_i32.cc | 2 +- src/tir/transforms/hoist_expression.cc | 8 +- src/tir/transforms/inject_double_buffer.cc | 4 +- src/tir/transforms/inject_permuted_layout.cc | 2 +- src/tir/transforms/inject_ptx_async_copy.cc | 2 +- src/tir/transforms/inject_ptx_ldg32.cc | 4 +- src/tir/transforms/inject_rolling_buffer.cc | 4 +- .../transforms/inject_software_pipeline.cc | 3 +- src/tir/transforms/inject_virtual_thread.cc | 4 +- .../transforms/inline_private_functions.cc | 5 +- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/lift_thread_binding.cc | 2 +- src/tir/transforms/loop_partition.cc | 4 +- src/tir/transforms/lower_async_dma.cc | 2 +- .../lower_cross_thread_reduction.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 4 +- .../transforms/lower_device_kernel_launch.cc | 4 +- .../lower_device_storage_access_info.cc | 4 +- src/tir/transforms/lower_init_block.cc | 2 +- src/tir/transforms/lower_intrin.cc | 4 +- src/tir/transforms/lower_match_buffer.cc | 2 +- src/tir/transforms/lower_opaque_block.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 4 +- src/tir/transforms/lower_tvm_builtin.cc | 4 +- src/tir/transforms/lower_vtcm_alloc.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 4 +- src/tir/transforms/make_packed_api.cc | 6 +- src/tir/transforms/make_unpacked_api.cc | 4 +- .../manifest_shared_memory_local_stage.cc | 2 +- .../transforms/memhammer_lower_auto_copy.cc | 4 +- src/tir/transforms/memhammer_rewrite_rule.h | 2 +- .../merge_shared_memory_allocations.cc | 4 +- src/tir/transforms/narrow_datatype.cc | 4 +- .../plan_update_buffer_allocation_location.cc | 2 +- src/tir/transforms/primfunc_utils.cc | 6 +- src/tir/transforms/profile_instrumentation.cc | 2 +- .../reduce_branching_through_overcompute.cc | 2 +- src/tir/transforms/remap_thread_axis.cc | 4 +- src/tir/transforms/remove_assume.cc | 4 +- src/tir/transforms/remove_no_op.cc | 4 +- src/tir/transforms/remove_store_undef.cc | 4 +- .../remove_weight_layout_rewrite_block.cc | 2 +- src/tir/transforms/renew_defs.cc | 2 +- .../transforms/renormalize_split_pattern.cc | 4 +- src/tir/transforms/rewrite_unsafe_select.cc | 4 +- src/tir/transforms/simplify.cc | 4 +- src/tir/transforms/skip_assert.cc | 4 +- src/tir/transforms/split_host_device.cc | 4 +- src/tir/transforms/storage_rewrite.cc | 6 +- .../transforms/tensorcore_infer_fragment.cc | 4 +- src/tir/transforms/thread_storage_sync.cc | 4 +- .../transforms/transform_mma_buffer_layout.cc | 2 +- src/tir/transforms/unify_thread_binding.cc | 2 +- src/tir/transforms/unroll_loop.cc | 4 +- .../transforms/unsupported_dtype_legalize.cc | 10 +- .../using_assume_to_reduce_branches.cc | 2 +- src/tir/transforms/vectorize_loop.cc | 4 +- src/topi/broadcast.cc | 35 +- src/topi/einsum.cc | 2 +- src/topi/elemwise.cc | 75 +- src/topi/nn.cc | 116 +-- src/topi/reduction.cc | 25 +- src/topi/transform.cc | 169 ++-- src/topi/utils.cc | 8 +- src/topi/vision.cc | 9 +- tests/cpp-runtime/hexagon/run_all_tests.cc | 4 +- tests/cpp-runtime/hexagon/run_unit_tests.cc | 4 +- tests/cpp-runtime/opencl/texture_copy_test.cc | 2 +- tests/cpp/llvm_codegen_registry_test.cc | 2 +- .../python/contrib/test_hexagon/README_RPC.md | 4 +- tests/python/runtime/test_runtime_rpc.py | 6 +- version.py | 4 +- web/emcc/tvmjs_support.cc | 4 +- web/emcc/wasm_runtime.cc | 5 +- 625 files changed, 3211 insertions(+), 5383 deletions(-) create mode 100644 include/tvm/runtime/base.h delete mode 100644 include/tvm/runtime/c_runtime_api.h delete mode 100644 include/tvm/runtime/registry.h delete mode 100644 src/runtime/c_runtime_api.cc delete mode 100644 src/runtime/container.cc create mode 100644 src/runtime/device_api.cc delete mode 100644 src/runtime/object.cc delete mode 100644 src/runtime/object_internal.h delete mode 100644 src/runtime/registry.cc delete mode 100644 src/runtime/runtime_base.h diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index bbccc75af117..3e07e778d78f 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit bbccc75af117473f6de81905bd3314775f41636e +Subproject commit 3e07e778d78f0fcd047533c1fdaed571a68a396f diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 26085bc366f4..5255d3f4b10a 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -42,9 +42,8 @@ #include "../ffi/src/ffi/object.cc" #include "../ffi/src/ffi/testing.cc" #include "../ffi/src/ffi/traceback.cc" -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/container.cc" #include "../src/runtime/cpu_device_api.cc" +#include "../src/runtime/device_api.cc" #include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" #include "../src/runtime/library_module.cc" @@ -53,7 +52,6 @@ #include "../src/runtime/minrpc/minrpc_logger.cc" #include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" #include "../src/runtime/profiling.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/rpc/rpc_channel.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 88ad99e47af2..e5a5154acbf2 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,7 +20,9 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ -#include +#include +#include +#include #include #ifndef _WIN32 diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index dbb0a62d2c5d..a5d3f6957c33 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -24,7 +24,7 @@ #ifndef TVM_APPS_CPP_RPC_ENV_H_ #define TVM_APPS_CPP_RPC_ENV_H_ -#include +#include #include diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index c4ee4d35450f..2f74dd309f42 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -21,7 +21,7 @@ * \file rpc_server.cc * \brief RPC Server implementation. */ -#include +#include #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) #include #include @@ -398,6 +398,6 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate); +TVM_FFI_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index e4565d095b2e..9bb61065c58a 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -26,7 +26,7 @@ #include -#include "tvm/runtime/c_runtime_api.h" +#include "tvm/runtime/base.h" namespace tvm { namespace runtime { diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index f4fc9fb365a8..aebde97a51f4 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -19,9 +19,9 @@ #include "launcher_core.h" +#include #include #include -#include #include #include diff --git a/apps/ios_rpc/tvmrpc/RPCServer.mm b/apps/ios_rpc/tvmrpc/RPCServer.mm index 3dc2fb0c192a..4717d7103254 100644 --- a/apps/ios_rpc/tvmrpc/RPCServer.mm +++ b/apps/ios_rpc/tvmrpc/RPCServer.mm @@ -23,8 +23,8 @@ #import "RPCServer.h" +#include #include -#include #include #include diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 243e4819d025..8d0ae7368d8a 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -23,7 +23,7 @@ #import -#include +#include #include "RPCArgs.h" @@ -51,14 +51,14 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.workpath") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { static const std::string base_ = NSTemporaryDirectory().UTF8String; const auto path = args[0].cast(); *rv = base_ + "/" + path; }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.load_module") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); std::string fmt = GetFileFormat(name, ""); @@ -109,7 +109,7 @@ void Init(const std::string& name) { }; // Add UnsignedDSOLoader plugin in global registry -TVM_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto n = make_object(); n->Init(args[0]); diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index e39468f0bf78..09867f88fa36 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -153,18 +153,18 @@ then be registered with the following steps. #. Register the function to the tvm registry:: - TVM_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); + TVM_FFI_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); -.. _c_runtime_api.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h +.. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h #. Add an entry for the new DeviceAPI to the ``TVMDeviceExtType`` enum - in `c_runtime_api.h`_. The value should be an unused value greater + in `base.h`_. The value should be an unused value greater than ``DLDeviceType::kDLExtDev``, but less than ``DeviceAPIManager::kMaxDeviceAPI``. #. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the enum value to a string representation. This string representation - should match the name given to ``TVM_REGISTER_GLOBAL``. + should match the name given to ``TVM_FFI_REGISTER_GLOBAL``. #. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. @@ -225,7 +225,7 @@ the same name as was used in the ``TVM_REGISTER_TARGET_KIND`` definition above. :: tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target); - TVM_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); + TVM_FFI_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); The code generator takes two arguments. The first is the ``IRModule`` to compile, and the second is the ``Target`` that describes the device diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 85e9f45a5fba..bf7b52229d13 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -376,7 +376,7 @@ Python when needed. return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } - TVM_REGISTER_GLOBAL("relax.transform.FoldConstant") + TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant") .set_body_typed(FoldConstant); } // namespace transform diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index f797039ee386..55c523cb4cc4 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -80,7 +80,7 @@ The following example registers PackedFunc in C++ and calls from python. .. code:: c // register a global packed function in c++ - TVM_REGISTER_GLOBAL("myadd") + TVM_FFI_REGISTER_GLOBAL("myadd") .set_body_packed(MyAdd); .. code:: python @@ -110,7 +110,7 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c - TVM_REGISTER_GLOBAL("callhello") + TVM_FFI_REGISTER_GLOBAL("callhello") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { PackedFunc f = args[0]; f("hello world"); @@ -134,7 +134,7 @@ which allows us to embed the PackedFunc into any languages. Besides python, so f `java`_ and `javascript`_. This philosophy of embedded API is very like Lua, except that we don't have a new language but use C++. -.. _minimum C API: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h +.. _minimum C API: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h .. _java: https://github.com/apache/tvm/tree/main/jvm .. _javascript: https://github.com/apache/tvm/tree/main/web @@ -282,7 +282,7 @@ Each argument in PackedFunc contains a union value `TVMValue`_ and a type code. This design allows the dynamically typed language to convert to the corresponding type directly, and statically typed language to do runtime type checking during conversion. -.. _TVMValue: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h#L135 +.. _TVMValue: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h#L135 The relevant files are diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 131f2e73e08a..df8265d0b9c7 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -27,6 +27,13 @@ #include #include +// Macros to do weak linking +#ifdef _MSC_VER +#define TVM_FFI_WEAK __declspec(selectany) +#else +#define TVM_FFI_WEAK __attribute__((weak)) +#endif + #if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) #include #define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 5162df6d830d..753d4f50f19c 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -787,7 +787,7 @@ class Function::Registry { * .set_body_typed(multiply); // will have type int(int, int) * * // will have type int(int, int) - * TVM_REGISTER_GLOBAL("sub") + * TVM_FFI_REGISTER_GLOBAL("sub") * .set_body_typed([](int a, int b) -> int { return a - b; }); * * \endcode diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 8eaa62a98120..9c758a52b384 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -25,13 +25,13 @@ #ifndef TVM_IR_OP_H_ #define TVM_IR_OP_H_ +#include #include #include #include #include #include #include -#include #include #include diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 7b79a2c89455..2752d9951a3f 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -23,10 +23,10 @@ #ifndef TVM_IR_SOURCE_MAP_H_ #define TVM_IR_SOURCE_MAP_H_ +#include #include #include #include -#include #include #include diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index aebff220abbc..8a9e763fecbf 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -39,7 +39,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index fb15ff4f81e5..ab197078f317 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h index c99d0f7f73fb..5a8e098cfd6e 100644 --- a/include/tvm/node/serialization.h +++ b/include/tvm/node/serialization.h @@ -24,7 +24,7 @@ #ifndef TVM_NODE_SERIALIZATION_H_ #define TVM_NODE_SERIALIZATION_H_ -#include +#include #include #include diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 2cee3bca631b..81d6d4eb379e 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -23,11 +23,11 @@ #ifndef TVM_RELAX_EXEC_BUILDER_H_ #define TVM_RELAX_EXEC_BUILDER_H_ +#include #include #include #include #include -#include #include #include diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 7e4149fe5548..bd75197bfe21 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -24,10 +24,10 @@ #ifndef TVM_RELAX_TYPE_H_ #define TVM_RELAX_TYPE_H_ +#include #include #include #include -#include #include #include diff --git a/include/tvm/runtime/base.h b/include/tvm/runtime/base.h new file mode 100644 index 000000000000..c704decb63e9 --- /dev/null +++ b/include/tvm/runtime/base.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/runtime/base.h + * \brief base macros + */ +#ifndef TVM_RUNTIME_BASE_H_ +#define TVM_RUNTIME_BASE_H_ + +// TVM runtime fully relies on TVM FFI C API +// we will avoid defining extra C APIs here +#include + +// TVM version +#define TVM_VERSION "0.21.dev0" + +// define extra macros for TVM DLL exprt +#ifdef __EMSCRIPTEN__ +#include +#define TVM_DLL EMSCRIPTEN_KEEPALIVE +#endif + +// helper macro to suppress unused warning +#if defined(__GNUC__) +#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_ATTRIBUTE_UNUSED +#endif + +#ifndef TVM_DLL +#ifdef _WIN32 +#ifdef TVM_EXPORTS +#define TVM_DLL __declspec(dllexport) +#else +#define TVM_DLL __declspec(dllimport) +#endif +#else +#define TVM_DLL __attribute__((visibility("default"))) +#endif +#endif + +#endif // TVM_RUNTIME_BASE_H_ diff --git a/include/tvm/runtime/builtin_fp16.h b/include/tvm/runtime/builtin_fp16.h index 5b54583da4ff..3ea670017d3d 100644 --- a/include/tvm/runtime/builtin_fp16.h +++ b/include/tvm/runtime/builtin_fp16.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_BUILTIN_FP16_H_ #define TVM_RUNTIME_BUILTIN_FP16_H_ -#include +#include #include diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index eb8d7270b137..0d84b55fe318 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -28,28 +28,12 @@ #ifndef TVM_RUNTIME_C_BACKEND_API_H_ #define TVM_RUNTIME_C_BACKEND_API_H_ -#include +#include #ifdef __cplusplus extern "C" { #endif -/*! - * \brief Signature for backend functions exported as DLL. - * - * \param args The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * \param out_ret_value The output value of the return value. - * \param out_ret_tcode The output type code of the return value. - * \param resource_handle Pointer to associated resource. - * - * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. - */ -typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, - TVMValue* out_ret_value, int* out_ret_tcode, - void* resource_handle); - /*! * \brief Backend function for modules to get function * from its environment mod_node (its imports and global function). @@ -60,7 +44,8 @@ typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_ar * \param out The result function. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, + TVMFFIObjectHandle* out); /*! * \brief Backend function to register system-wide library symbol. @@ -100,19 +85,6 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t */ TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); -/*! - * \brief Backend function to register execution environment(e.g. python) - * specific C APIs. - * - * \note We only register the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). In most cases we should use the ffi::Function FFI. - * - * \param name The name of the symbol - * \param ptr The symbol address. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_DLL int TVMBackendRegisterEnvCAPI(const char* name, void* ptr); - /*! * \brief Environment for TVM parallel task. */ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h deleted file mode 100644 index b802dbc22839..000000000000 --- a/include/tvm/runtime/c_runtime_api.h +++ /dev/null @@ -1,732 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/runtime/c_runtime_api.h - * \brief TVM runtime library. - * - * The philosophy of TVM project is to customize the compilation - * stage to generate code that can used by other projects transparently. - * So this is a minimum runtime code gluing, and some limited - * memory management code to enable quick testing. - * - * The runtime API is independent from TVM compilation stack and can - * be linked via libtvm_runtime. - * - * The common flow is: - * - Use TVMFuncListGlobalNames to get global function name - * - Use TVMFuncCall to call these functions. - * - * Possible return values of the API functions: - * * 0: success - * * -1: the error can be retrieved through TVMGetLastError. - * * -2: a frontend error occurred and recorded in the frontend. - */ -#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ -#define TVM_RUNTIME_C_RUNTIME_API_H_ - -// Macros to do weak linking -#ifdef _MSC_VER -#define TVM_WEAK __declspec(selectany) -#else -#define TVM_WEAK __attribute__((weak)) -#endif - -#ifdef __EMSCRIPTEN__ -#include -#define TVM_DLL EMSCRIPTEN_KEEPALIVE -#endif - -// helper macro to suppress unused warning -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -#ifndef TVM_DLL -#ifdef _WIN32 -#ifdef TVM_EXPORTS -#define TVM_DLL __declspec(dllexport) -#else -#define TVM_DLL __declspec(dllimport) -#endif -#else -#define TVM_DLL __attribute__((visibility("default"))) -#endif -#endif - -// TVM version -#define TVM_VERSION "0.21.dev0" - -// TVM Runtime is DLPack compatible. -#include - -#ifdef __cplusplus -extern "C" { -#endif -#include -#include -#include - -/*! \brief type of array index. */ -typedef int64_t tvm_index_t; - -/*! \brief Extension device types in TVM - * - * Additional enumerators to supplement those provided by - * DLPack's `DLDeviceType` enumeration. - * - * MAINTAINERS NOTE #1: We need to ensure that the two devices - * are identified by the same integer. - * Currently this requires manual verification. - * Discussed here: https://github.com/dmlc/dlpack/issues/111 - * As of DLPack v0.7, the highest-valued enumerator in - * `DLDeviceType` is kDLHexagon = 16. - * - * MAINTAINERS NOTE #2: As of DLPack v0.7, the definition for - * `DLDeviceType` specifies an underlying storage type of - * `int32_t`. That guarantees a variable of type - * `DLDeviceType` is capable of holding any integers provided - * by *either* of these enumerations. - * - * However, the `int32_t` specification only applies when the - * header file is compiled as C++, and this header file is also - * meant to work as C code. So the unspecified storage type - * could be a latent bug when compiled as C. - */ -#ifdef __cplusplus -typedef enum : int32_t { -#else -typedef enum { -#endif - // To help avoid accidental conflicts between `DLDeviceType` - // and this enumeration, start numbering the new enumerators - // a little higher than (currently) seems necessary. - TVMDeviceExtType_End = 36, // sentinel value -} TVMDeviceExtType; - -#ifdef __cplusplus -// Some other parts of TVM hardcode the integer identifier for -// some DLPack / TVM devices, rather then using the symbolic -// enumerator. E.g., `2` rather than `kDLCUDA`. -// These asserts should alert us when that mapping breaks. -#define TVM_HARCODED_INTEGER_CHANGED_MSG \ - "Change in compile-time integer. Make sure hardcoded uses of this integer throughout TVM are " \ - "updated." -static_assert(kDLCPU == 1, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDA == 2, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDAHost == 3, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLOpenCL == 4, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLVulkan == 7, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLMetal == 8, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLVPI == 9, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLROCM == 10, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLROCMHost == 11, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLExtDev == 12, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDAManaged == 13, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLOneAPI == 14, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLWebGPU == 15, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLHexagon == 16, TVM_HARCODED_INTEGER_CHANGED_MSG); - -#undef TVM_HARCODED_INTEGER_CHANGED_MSG -#endif - -/*! - * \brief The type code in used and only used in TVM FFI for argument passing. - * - * DLPack consistency: - * 1) kTVMArgInt is compatible with kDLInt - * 2) kTVMArgFloat is compatible with kDLFloat - * 3) kDLUInt is not in ArgTypeCode, but has a spared slot - * - * Downstream consistency: - * The kDLInt, kDLUInt, kDLFloat are kept consistent with the original ArgType code - * - * It is only used in argument passing, and should not be confused with - * DataType::TypeCode, which is DLPack-compatible. - * - * \sa tvm::runtime::DataType::TypeCode - */ -typedef enum { - kTVMArgInt = kDLInt, - kTVMArgFloat = kDLFloat, - kTVMOpaqueHandle = 3U, - kTVMNullptr = 4U, - kTVMDataType = 5U, - kDLDevice = 6U, - kTVMDLTensorHandle = 7U, - kTVMObjectHandle = 8U, - kTVMModuleHandle = 9U, - kTVMPackedFuncHandle = 10U, - kTVMStr = 11U, - kTVMBytes = 12U, - kTVMNDArrayHandle = 13U, - kTVMObjectRValueRefArg = 14U, - kTVMArgBool = 15U, - // Extension codes for other frameworks to integrate TVM ffi::Function. - // To make sure each framework's id do not conflict, use first and - // last sections to mark ranges. - // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 16U, - kTVMNNVMFirst = 16U, - kTVMNNVMLast = 20U, - // The following section of code is used for non-reserved types. - kTVMExtReserveEnd = 64U, - kTVMExtEnd = 128U, -} TVMArgTypeCode; - -/*! \brief the array handle */ -typedef DLTensor* TVMArrayHandle; - -/*! - * \brief Union type of values - * being passed through API and function calls. - */ -typedef union { - int64_t v_int64; - double v_float64; - void* v_handle; - const char* v_str; - DLDataType v_type; - DLDevice v_device; -} TVMValue; - -/*! - * \brief Byte array type used to pass in byte array - * When kTVMBytes is used as data type. - */ -typedef struct { - const char* data; - size_t size; -} TVMByteArray; - -/*! \brief Handle to TVM runtime modules. */ -typedef void* TVMModuleHandle; -/*! \brief Handle to packed function handle. */ -typedef void* TVMFunctionHandle; -/*! \brief Handle to hold return value. */ -typedef void* TVMRetValueHandle; -/*! - * \brief The stream that is specific to device - * can be NULL, which indicates the default one. - */ -typedef void* TVMStreamHandle; -/*! \brief Handle to Object. */ -typedef void* TVMObjectHandle; - -/*! - * \brief Used for implementing C API function. - * Set last error message before return. - * \param msg The error message to be set. - */ -TVM_DLL void TVMAPISetLastError(const char* msg); - -/*! - * \brief Used for implementing C API function. - * Set last exception before return. - * \param py_object The python exception to be set - */ -TVM_DLL void TVMAPISetLastPythonError(void* py_object); - -/*! \brief Return the previous python error, if any. - * - * Used to propagate the original Python exception to a python - * try/except, when there are C++ stack frames between the location thro - * - * \return The previous argument passed during the most recent call to - * TVMAPISetLastPythonError. If TVMAPISetLastPythonError has not - * been called, or if TVMDropLastPythonError has been called since - * the most recent to TVMAPISetLastPythonError, returns nullptr. - */ -TVM_DLL void* TVMGetLastPythonError(); - -/*! - * \brief return str message of the last error - * all function in this file will return 0 when success - * and nonzero when an error occurred, - * TVMGetLastError can be called to retrieve the error - * - * this function is threadsafe and can be called by different thread - * \return error info - */ -TVM_DLL const char* TVMGetLastError(void); - -/*! - * \brief Return the backtrace of the most recent error - * - * Returns the backtrace of the most recent error, if an error exists, - * and the error contains a backtrace. If no error exists or the - * error does not contain a backtrace, returns nullptr. - * - * \return The backtrace of the most recent error - */ -TVM_DLL const char* TVMGetLastBacktrace(); - -/*! - * \brief Remove the propagated python error, if any - * - * Removes the TVM-held reference to a thrown python exception object. - * Because these objects contain references to the stack frames from - * which the exception was thrown, maintaining a reference to an - * exception object prevents any local python variables from being - * garbage-collected. After retrieving the object using - * TVMGetLastPythonError, the Python FFI interface uses this method to - * clear the TVM-held reference to the exception, to allow garbage - * collection to continue. - */ -TVM_DLL void TVMDropLastPythonError(); - -/*! \brief Re-throw the most recent error. - * - * If an error was previously set using TVMAPISetLastError or - * TVMAPISetLastPythonError, re-throw the error. This is similar to - * `LOG(FATAL) << TVMGetLastError()`, but includes handling to - * propagate a python exception across C++ stack frames, or to append - * a stack trace to an error message. - */ -TVM_DLL void TVMThrowLastError(); - -/*! - * \brief Load module from file. - * \param file_name The file name to load the module from. - * \param format The format of the module. - * \param out The result module - * - * \return 0 when success, nonzero when failure happens - * \note The resulting module do not contain import relation. - * It can be reconstructed by TVMModImport. - */ -TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out); - -/*! - * \brief Add dep to mod's dependency. - * This allows functions in this module to use modules. - * - * \param mod The module handle. - * \param dep The dependent module to be imported. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep); - -/*! - * \brief Get function from the module. - * \param mod The module handle. - * \param func_name The name of the function. - * \param query_imports Whether to query imported modules - * \param out The result function, can be NULL if it is not available. - * \return 0 when no error is thrown, nonzero when failure happens - */ -TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, - TVMFunctionHandle* out); - -/*! - * \brief Free the Module - * \param mod The module to be freed. - * - * \note This may not free up the module's resources. - * If there is active TVMFunctionHandle uses the module - * Or if this module is imported by another active module. - * - * The all functions remains valid until TVMFuncFree is called. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMModFree(TVMModuleHandle mod); - -/*! - * \brief Free the function when it is no longer needed. - * \param func The function handle - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncFree(TVMFunctionHandle func); - -/*! - * \brief Call a Packed TVM Function. - * - * \param func node handle of the function. - * \param arg_values The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * - * \param ret_val The return value. - * \param ret_type_code the type code of return value. - * - * \return 0 when success, nonzero when failure happens - * \note TVM calls always exchanges with type bits=64, lanes=1 - * - * \note API calls always exchanges with type bits=64, lanes=1 - * If API call returns container handles (e.g. FunctionHandle) - * these handles should be managed by the front-end. - * The front-end need to call free function (e.g. TVMFuncFree) - * to free these handles. - */ -TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, - TVMValue* ret_val, int* ret_type_code); - -/*! - * \brief Set the return value of TVMPackedCFunc. - * - * This function is called by TVMPackedCFunc to set the return value. - * When this function is not called, the function returns null by default. - * - * \param ret The return value handle, pass by ret in TVMPackedCFunc - * \param value The value to be returned. - * \param type_code The type of the value to be returned. - * \param num_ret Number of return values, for now only 1 is supported. - */ -TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret); - -/*! - * \brief Inplace translate callback argument value to return value. - * This is only needed for non-POD arguments. - * - * \param value The value to be translated. - * \param code The type code to be translated. - * \note This function will do a shallow copy when necessary. - * - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); - -/*! - * \brief C type of packed function. - * - * \param args The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * \param ret The return value handle. - * \param resource_handle The handle additional resouce handle from front-end. - * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. - * \sa TVMCFuncSetReturn - */ -typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle); - -/*! - * \brief C callback to free the resource handle in C packed function. - * \param resource_handle The handle additional resouce handle from front-end. - */ -typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); - -/*! - * \brief Signature for extension function declarer. - * - * TVM call this function to get the extension functions - * The declarer will call register_func to register function and their name. - * - * \param register_func_handle The register function - * \return 0 if success, -1 if failure happens - */ -typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); - -/*! - * \brief Wrap a TVMPackedCFunc to become a FunctionHandle. - * - * The resource_handle will be managed by TVM API, until the function is no longer used. - * - * \param func The packed C function. - * \param resource_handle The resource handle from front-end, can be NULL. - * \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL - * \param out the result function handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, - TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out); - -/*! - * \brief Register the function to runtime's global table. - * - * The registered function then can be pulled by the backend by the name. - * - * \param name The name of the function. - * \param f The function to be registered. - * \param override Whether allow override already registered function. - */ -TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override); - -/*! - * \brief Get a global function. - * - * \param name The name of the function. - * \param out the result function pointer, NULL if it does not exist. - * - * \note The function handle of global function is managed by TVM runtime, - * So TVMFuncFree is should not be called when it get deleted. - */ -TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); - -/*! - * \brief List all the globally registered function name - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); - -/*! - * \brief Remove a global function. - * \param name The name of the function. - */ -TVM_DLL int TVMFuncRemoveGlobal(const char* name); - -// Array related apis for quick proptyping -/*! - * \brief Allocate a nd-array's memory, - * including space of shape, of given spec. - * - * \param shape The shape of the array, the data content will be copied to out - * \param ndim The number of dimension of the array. - * \param dtype_code The type code of the dtype - * \param dtype_bits The number of bits of dtype - * \param dtype_lanes The number of lanes in the dtype. - * \param device_type The device type. - * \param device_id The device id. - * \param out The output handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); - -/*! - * \brief Free the TVM Array. - * \param handle The array handle to be freed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayFree(TVMArrayHandle handle); - -/*! - * \brief Copy array data from CPU byte array. - * \param handle The array handle. - * \param data the data pointer - * \param nbytes The number of bytes to copy. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes); - -/*! - * \brief Copy array data to CPU byte array. - * \param handle The array handle. - * \param data the data pointer - * \param nbytes The number of bytes to copy. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes); - -/*! - * \brief Copy the array, both from and to must be valid during the copy. - * \param from The array to be copied from. - * \param to The target space. - * \param stream The stream where the copy happens, can be NULL. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); - -/*! - * \brief Produce an array from the DLManagedTensor that shares data memory - * with the DLManagedTensor. - * \param from The source DLManagedTensor. - * \param out The output array handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with - * the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out); - -/*! - * \brief Delete (free) a DLManagedTensor's data. - * \param dltensor Pointer to the DLManagedTensor. - */ -TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor); - -/*! - * \brief Create a new runtime stream. - * - * \param device_type The device type. - * \param device_id The device id. - * \param out The new stream handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out); - -/*! - * \brief Free a created stream handle. - * - * \param device_type The device type. - * \param device_id The device id. - * \param stream The stream to be freed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream); - -/*! - * \brief Set the runtime stream of current thread to be stream. - * The subsequent calls to the same device_type - * will use the setted stream handle. - * The specific type of stream is runtime device dependent. - * - * \param device_type The device type. - * \param device_id The device id. - * \param handle The stream handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle); - -/*! - * \brief Wait until all computations on stream completes. - * - * \param device_type The device type. - * \param device_id The device id. - * \param stream The stream to be synchronized. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); - -/*! - * \brief Synchronize two streams of execution. - * - * \param device_type The device type. - * \param device_id The device id. - * \param src The source stream to synchronize. - * \param dst The destination stream to synchronize. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, - TVMStreamHandle dst); - -/*! - * \brief Get the type_index from an object. - * - * \param obj The object handle. - * \param out_tindex the output type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); - -/*! - * \brief Convert type index to type key. - * \param tindex The type index. - * \param out_type_key The output type key. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); - -/*! - * \brief Increase the reference count of an object. - * - * \param obj The object handle. - * \note Internally we increase the reference counter of the object. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectRetain(TVMObjectHandle obj); - -/*! - * \brief Free the object. - * - * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectFree(TVMObjectHandle obj); - -/*! - * \brief Free a TVMByteArray returned from TVMFuncCall, and associated memory. - * \param arr The TVMByteArray instance. - * \return 0 on success, -1 on failure. - */ -TVM_DLL int TVMByteArrayFree(TVMByteArray* arr); - -/*! - * \brief Allocate a data space on device. - * \param dev The device to perform operation. - * \param nbytes The number of bytes in memory. - * \param alignment The alignment of the memory. - * \param type_hint The type of elements. Only needed by certain backends such - * as nbytes & alignment are sufficient for most backends. - * \param out_data The allocated device pointer. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, - DLDataType type_hint, void** out_data); - -/*! - * \brief Allocate a data space on device with special memory scope. - * \note The memory could use a special multi-dimensional memory layout. - * That is why we pass shape and dtype instead of raw number of bytes. - * \param dev The device to perform operation. - * \param ndim The number of dimension of the tensor. - * \param shape The shape of the tensor. - * \param dtype The type of elements. - * \param mem_scope The memory scope of the tensor, - * can be nullptr, which indicate the default global DRAM - * \param out_data The allocated device pointer. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, - DLDataType dtype, const char* mem_scope, - void** out_data); - -/*! - * \brief Free a data space on device. - * \param dev The device to perform operation. - * \param ptr The data space. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr); - -/*! - * \brief Copy data from one place to another. - * \note This API is designed to support special memory with shape dependent layout. - * We pass in DLTensor* with shape information to support these cases. - * \param from The source tensor. - * \param to The target tensor. - * \param stream Optional stream object. - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream); - -/*! - * \brief Check that an object is derived from another. - * \param child_type_index The type index of the derived type. - * \param parent_type_index The type index of the parent type. - * \param is_derived A boolean representing whether this predicate holds. - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, - int* is_derived); - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index b5b6ad19d228..d5f3c6ee3d7f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,8 +24,9 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ +#include #include -#include +#include #include #include @@ -35,6 +36,8 @@ namespace tvm { namespace runtime { +using tvm_index_t = ffi::Shape::index_type; + /*! * \brief Runtime primitive data type. * diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index a4b53eb79734..7366b9895d5e 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -26,10 +26,15 @@ #include #include -#include +#include #include #include +/*! + * \brief The stream that is specific to device + * can be NULL, which indicates the default one. + */ +typedef void* TVMStreamHandle; namespace tvm { @@ -37,6 +42,41 @@ namespace tvm { using Device = DLDevice; namespace runtime { + +/*! \brief Extension device types in TVM + * + * Additional enumerators to supplement those provided by + * DLPack's `DLDeviceType` enumeration. + * + * MAINTAINERS NOTE #1: We need to ensure that the two devices + * are identified by the same integer. + * Currently this requires manual verification. + * Discussed here: https://github.com/dmlc/dlpack/issues/111 + * As of DLPack v0.7, the highest-valued enumerator in + * `DLDeviceType` is kDLHexagon = 16. + * + * MAINTAINERS NOTE #2: As of DLPack v0.7, the definition for + * `DLDeviceType` specifies an underlying storage type of + * `int32_t`. That guarantees a variable of type + * `DLDeviceType` is capable of holding any integers provided + * by *either* of these enumerations. + * + * However, the `int32_t` specification only applies when the + * header file is compiled as C++, and this header file is also + * meant to work as C code. So the unspecified storage type + * could be a latent bug when compiled as C. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + // To help avoid accidental conflicts between `DLDeviceType` + // and this enumeration, start numbering the new enumerators + // a little higher than (currently) seems necessary. + TVMDeviceExtType_End = 36, // sentinel value +} TVMDeviceExtType; + /*! * \brief the query type into GetAttr */ diff --git a/include/tvm/runtime/disco/cuda_ipc_memory.h b/include/tvm/runtime/disco/cuda_ipc_memory.h index 120e6a543179..ea272052626f 100644 --- a/include/tvm/runtime/disco/cuda_ipc_memory.h +++ b/include/tvm/runtime/disco/cuda_ipc_memory.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_ #define TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_ -#include +#include #include #include diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 807c9dbf30bc..da715848e09a 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index 7b19cdb1cea4..f103c6f30ac8 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_ #define TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_ -#include +#include #include #include diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index a4b192db8c1c..c02e312b71a0 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 2b36306e5963..6eebe49ff135 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include #include @@ -166,20 +166,15 @@ class NDArray : public tvm::ffi::NDArray { TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); - struct Internal; - - protected: - /*! - * \brief DecRef resource managed by an FFI array handle. - * \param handle The array handle. - */ - inline static void FFIDecRef(TVMArrayHandle handle); /*! - * \brief Get FFI Array handle from ndarray. - * \param nd The object with ndarray type. - * \return The result array handle. + * \brief Function to copy data from one array to a byte buffer. + * \param from The source array. + * \param to The target byte buffer. + * \param nbytes The size of the data buffer. + * \param stream The stream used in copy. */ - inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd); + TVM_DLL static void CopyToBytes(const DLTensor* from, void* to, size_t nbytes, + TVMStreamHandle stream = nullptr); }; /*! @@ -211,28 +206,6 @@ inline void NDArray::CopyTo(const NDArray& other) const { CopyFromTo(get_mutable(), other.get_mutable()); } -inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { - // NOTE: it is necessary to cast to container then to base - // so that the FFI handle uses the ContainerBase address. - auto ptr = reinterpret_cast( - TVMFFINDArrayGetDLTensorPtr(static_cast(const_cast(nd.get())))); - return ptr; -} - -inline TVMArrayHandle ObjectHandleToTVMArrayHandle(Object* handle) { - return reinterpret_cast( - TVMFFINDArrayGetDLTensorPtr(static_cast(handle))); -} - -inline Object* TVMArrayHandleToObjectHandle(void* handle) { - // NOTE: legacy patch here for TFM FFI - return reinterpret_cast(reinterpret_cast(handle) - sizeof(TVMFFIObject)); -} - -inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - ffi::details::ObjectUnsafe::DecRefObjectHandle(TVMArrayHandleToObjectHandle(handle)); -} - /*! \brief Magic number for NDArray file */ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; @@ -271,10 +244,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - ICHECK_EQ( - TVMArrayCopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), - 0) - << TVMGetLastError(); + NDArray::CopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); } diff --git a/include/tvm/runtime/nvtx.h b/include/tvm/runtime/nvtx.h index db99154b0b7c..289837c1fda1 100644 --- a/include/tvm/runtime/nvtx.h +++ b/include/tvm/runtime/nvtx.h @@ -19,7 +19,7 @@ #ifndef TVM_RUNTIME_NVTX_H_ #define TVM_RUNTIME_NVTX_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index c686b54d096a..6ce95eea1e83 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 235bdcf3e32f..6da06c119171 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include #include @@ -41,286 +41,6 @@ namespace runtime { using ffi::Any; using ffi::AnyView; -/*! - * \brief Utility function to convert legacy ffi::AnyView to AnyView - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { - TVMFFIAny res; - // clear first to ensure consistent hash - res.v_uint64 = 0; - switch (type_code) { - case kTVMArgInt: { - res.type_index = ffi::TypeIndex::kTVMFFIInt; - res.v_int64 = value.v_int64; - return res; - } - case kTVMArgFloat: { - res.type_index = ffi::TypeIndex::kTVMFFIFloat; - res.v_float64 = value.v_float64; - return res; - } - case kTVMOpaqueHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIOpaquePtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMNullptr: { - res.type_index = ffi::TypeIndex::kTVMFFINone; - return res; - } - case kTVMDataType: { - res.type_index = ffi::TypeIndex::kTVMFFIDataType; - res.v_dtype = value.v_type; - return res; - } - case kDLDevice: { - res.type_index = ffi::TypeIndex::kTVMFFIDevice; - res.v_device = value.v_device; - return res; - } - case kTVMDLTensorHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIDLTensorPtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMObjectHandle: { - res.v_obj = static_cast(value.v_handle); - res.type_index = res.v_obj->type_index; - return res; - } - case kTVMModuleHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIModule; - res.v_obj = static_cast(value.v_handle); - return res; - } - case kTVMPackedFuncHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIFunction; - res.v_obj = static_cast(value.v_handle); - return res; - } - case kTVMStr: { - res.type_index = ffi::TypeIndex::kTVMFFIRawStr; - res.v_c_str = value.v_str; - return res; - } - case kTVMBytes: { - res.type_index = ffi::TypeIndex::kTVMFFIByteArrayPtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMNDArrayHandle: { - res.type_index = ffi::TypeIndex::kTVMFFINDArray; - res.v_obj = reinterpret_cast(TVMArrayHandleToObjectHandle(value.v_handle)); - return res; - } - case kTVMArgBool: { - res.type_index = ffi::TypeIndex::kTVMFFIBool; - res.v_int64 = value.v_int64; - return res; - } - case kTVMObjectRValueRefArg: { - res.type_index = ffi::TypeIndex::kTVMFFIObjectRValueRef; - res.v_ptr = value.v_handle; - return res; - } - default: { - LOG(FATAL) << "Unsupported type code: " << type_code; - TVM_FFI_UNREACHABLE(); - } - } -} - -/*! - * \brief Utility function to convert legacy ffi::AnyView to AnyView - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline AnyView LegacyTVMArgValueToAnyView(TVMValue value, int type_code) { - return AnyView::CopyFromTVMFFIAny(LegacyTVMArgValueToFFIAny(value, type_code)); -} - -/*! - * \brief Utility function to convert legacy ffi::AnyView to Any - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline Any MoveLegacyTVMArgValueToAny(TVMValue value, int type_code) { - return ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(LegacyTVMArgValueToFFIAny(value, type_code)); -} - -/* - * \brief Convert AnyView to legacy TVMValue and type_code - * \param src The AnyView to convert - * \param value The TVMValue to store the result - * \param type_code The type code to store the result - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline void AnyViewToLegacyTVMArgValue(TVMFFIAny src, TVMValue* value, int* type_code) { - switch (src.type_index) { - case ffi::TypeIndex::kTVMFFIBool: { - type_code[0] = kTVMArgBool; - value[0].v_int64 = src.v_int64; - break; - } - case ffi::TypeIndex::kTVMFFIInt: { - type_code[0] = kDLInt; - value[0].v_int64 = src.v_int64; - break; - } - case ffi::TypeIndex::kTVMFFIFloat: { - type_code[0] = kDLFloat; - value[0].v_float64 = src.v_float64; - break; - } - case ffi::TypeIndex::kTVMFFIOpaquePtr: { - type_code[0] = kTVMOpaqueHandle; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFINone: { - type_code[0] = kTVMNullptr; - break; - } - case ffi::TypeIndex::kTVMFFIDataType: { - type_code[0] = kTVMDataType; - value[0].v_type = src.v_dtype; - break; - } - case ffi::TypeIndex::kTVMFFIDevice: { - type_code[0] = kDLDevice; - value[0].v_device = src.v_device; - break; - } - case ffi::TypeIndex::kTVMFFIDLTensorPtr: { - type_code[0] = kTVMDLTensorHandle; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFIRawStr: { - type_code[0] = kTVMStr; - value[0].v_str = src.v_c_str; - break; - } - case ffi::TypeIndex::kTVMFFIByteArrayPtr: { - type_code[0] = kTVMBytes; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFINDArray: { - type_code[0] = kTVMNDArrayHandle; - value[0].v_handle = ObjectHandleToTVMArrayHandle(reinterpret_cast(src.v_obj)); - break; - } - case ffi::TypeIndex::kTVMFFIModule: { - type_code[0] = kTVMModuleHandle; - value[0].v_handle = src.v_obj; - break; - } - case ffi::TypeIndex::kTVMFFIFunction: { - type_code[0] = kTVMPackedFuncHandle; - value[0].v_handle = src.v_obj; - break; - } - case ffi::TypeIndex::kTVMFFIObjectRValueRef: { - type_code[0] = kTVMObjectRValueRefArg; - value[0].v_handle = src.v_ptr; - break; - } - default: { - if (src.type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - type_code[0] = kTVMObjectHandle; - value[0].v_handle = src.v_obj; - break; - } - LOG(FATAL) << "Unsupported type index: " << src.type_index; - } - } -} - -/* - * \brief Move Any to legacy TVMValue and type_code - * \param src The Any to move - * \param value The TVMValue to store the result - * \param type_code The type code to store the result - */ -inline void MoveAnyToLegacyTVMValue(Any&& src, TVMValue* value, int* type_code) { - TVMFFIAny val = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src)); - // NOTE: conversion rule is the same as AnyViewToLegacyTVMArgValue - AnyViewToLegacyTVMArgValue(val, value, type_code); -} - -/*! - * \brief Translate legacy ffi::PackedArgs to PackedArgs - * \param value The TVMValue array - * \param type_code The type code array - * \param num_args The number of arguments - * \param dst The destination AnyView array - */ -inline void LegacyTVMArgsToPackedArgs(const TVMValue* value, const int* type_code, int num_args, - AnyView* dst) { - for (int i = 0; i < num_args; ++i) { - dst[i] = LegacyTVMArgValueToAnyView(value[i], type_code[i]); - } -} - -/*! - * \brief Translate legacy ffi::PackedArgs to PackedArgs - * \param args The AnyView array - * \param num_args The number of arguments - * \param value The TVMValue array - * \param type_code The type code array - */ -inline void PackedArgsToLegacyTVMArgs(const AnyView* args, int num_args, TVMValue* value, - int* type_code) { - for (int i = 0; i < num_args; ++i) { - AnyViewToLegacyTVMArgValue(args[i].CopyToTVMFFIAny(), value + i, type_code + i); - } -} - -/*! - * \brief Convert argument type code to string. - * \param type_code The input type code. - * \return The corresponding string repr. - */ -inline const char* ArgTypeCode2Str(int type_code) { - switch (type_code) { - case kDLInt: - return "int"; - case kTVMArgBool: - return "bool"; - case kDLUInt: - return "uint"; - case kDLFloat: - return "float"; - case kTVMStr: - return "str"; - case kTVMBytes: - return "bytes"; - case kTVMOpaqueHandle: - return "handle"; - case kTVMNullptr: - return "NULL"; - case kTVMDLTensorHandle: - return "ArrayHandle"; - case kTVMDataType: - return "DLDataType"; - case kDLDevice: - return "DLDevice"; - case kTVMPackedFuncHandle: - return "FunctionHandle"; - case kTVMModuleHandle: - return "ModuleHandle"; - case kTVMNDArrayHandle: - return "NDArrayContainer"; - case kTVMObjectHandle: - return "Object"; - case kTVMObjectRValueRefArg: - return "ObjectRValueRefArg"; - default: - LOG(FATAL) << "unknown type_code=" << static_cast(type_code); - } - throw; -} - namespace details { template diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 63814a637592..1e950f9f7a95 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -26,11 +26,11 @@ #include #include -#include +#include +#include #include #include #include -#include #include #include @@ -133,7 +133,7 @@ class Timer : public ObjectRef { * }; * TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); * - * TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { + * TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { * return Timer(make_object()); * }); * \endcode diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h deleted file mode 100644 index 94463dc7255f..000000000000 --- a/include/tvm/runtime/registry.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/registry.h - * \brief This file defines the TVM global function registry. - */ -#ifndef TVM_RUNTIME_REGISTRY_H_ -#define TVM_RUNTIME_REGISTRY_H_ - -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! \brief A class that wraps a Python object and preserves its ownership. - - * This class is used to wrap a PyObject* from the Python API and preserve its ownership. - * Allows for the creation of strong references to Python objects, which prevent them from being - * garbage-collected as long as the wrapper object exists. - */ -class WrappedPythonObject { - public: - /*! \brief Construct a wrapper that doesn't own anything */ - WrappedPythonObject() : python_obj_(nullptr) {} - - /*! \brief Conversion constructor from nullptr */ - explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {} - - /*! \brief Take ownership of a python object - * - * A new strong reference is created for the underlying python - * object. - * - * \param python_obj A PyObject* from the Python.h API. A new - * strong reference is created using Py_IncRef. - */ - explicit WrappedPythonObject(void* python_obj); - - /*! \brief Drop ownership of a python object - * - * Removes the strong reference held by the wrapper. - */ - ~WrappedPythonObject(); - - WrappedPythonObject(WrappedPythonObject&&); - WrappedPythonObject& operator=(WrappedPythonObject&&); - - WrappedPythonObject(const WrappedPythonObject&); - WrappedPythonObject& operator=(const WrappedPythonObject&); - WrappedPythonObject& operator=(std::nullptr_t); - - operator bool() { return python_obj_; } - - void* raw_pointer() { return python_obj_; } - - private: - void* python_obj_ = nullptr; -}; - -/*! - * \brief Register a function globally. - * \code - * TVM_REGISTER_GLOBAL("MyPrint") - * .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - * }); - * \endcode - */ -#define TVM_REGISTER_GLOBAL TVM_FFI_REGISTER_GLOBAL - -#define TVM_STRINGIZE_DETAIL(x) #x -#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) -#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) -/*! - * \brief Macro to include current line as string - */ -#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_REGISTRY_H_ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h index dc9d87025382..afaaea9e41e7 100644 --- a/include/tvm/runtime/relax_vm/executable.h +++ b/include/tvm/runtime/relax_vm/executable.h @@ -23,9 +23,9 @@ #ifndef TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ #define TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ +#include #include #include -#include #include #include diff --git a/include/tvm/runtime/relax_vm/ndarray_cache_support.h b/include/tvm/runtime/relax_vm/ndarray_cache_support.h index f595d81ffe7e..579fbf306f68 100644 --- a/include/tvm/runtime/relax_vm/ndarray_cache_support.h +++ b/include/tvm/runtime/relax_vm/ndarray_cache_support.h @@ -20,8 +20,8 @@ #define TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_ #include +#include #include -#include #include #include diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index b35cad368832..2cfd1de44dde 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -27,7 +27,7 @@ #include #include -#include +#include #include namespace dmlc { diff --git a/include/tvm/support/parallel_for.h b/include/tvm/support/parallel_for.h index 8bd2e6b825ab..aa9da30d8f1c 100644 --- a/include/tvm/support/parallel_for.h +++ b/include/tvm/support/parallel_for.h @@ -24,7 +24,7 @@ #ifndef TVM_SUPPORT_PARALLEL_FOR_H_ #define TVM_SUPPORT_PARALLEL_FOR_H_ -#include +#include #include #include diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 2035a511c1bb..c057422a0266 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -415,17 +415,15 @@ TVM_DLL const Op& tvm_thread_invariant(); * type codes are explicitly allocated. * * return_type tvm_call_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, + * TVMFFIAny* args_stack, * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); * const ffi::Function* f = env->GetFuncFromEnv(name); - * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * ffi::Any(value_stack + end, tcode_stack + end)); + * f->CallPacked(ffi::PackedArgs(args_stack[begin:end]), + * ffi::Any(args_stack + end)); * // return type can be int, float, handle. - * return cast(return_type, load_return_from(tcode_stack + end)) + * return cast(return_type, load_return_from(args_stack + end)) * } */ TVM_DLL const Op& tvm_call_packed_lowered(); @@ -451,17 +449,15 @@ TVM_DLL const Op& tvm_call_cpacked_lowered(); * (end - 1) value on the stack. * * return_type tvm_call_trace_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, + * TVMFFIAny* args_stack, * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); * const ffi::Function* f = env->GetFuncFromEnv(name); - * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * ffi::Any(value_stack + end, tcode_stack + end)); + * f->CallPacked(ffi::PackedArgs(args_stack[begin:end]), + * ffi::Any(args_stack + end)); * // return type can be int, float, handle. - * return cast(return_type, load_return_from(tcode_stack + end)) + * return cast(return_type, load_return_from(args_stack + end)) * } */ TVM_DLL const Op& tvm_call_trace_packed_lowered(); diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 078b41e32798..5f058f7d5e4c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index b80d4456c0be..eb64d87f9518 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -166,9 +166,9 @@ TVM_DLL Pass InstrumentBoundCheckers(); * f() * * if num_packed_args is not zero: - * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, + * f(void *, TVMFFIAny* packed_args, int num_packed_args, * api_arg_k, api_arg_k+1, ... api_arg_n, - * TVMValue* out_ret_val, int* out_ret_tcode) + * TVMFFIAny* out_ret_val) * * where n == len(api_args), k == num_packed_args * diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 5d04cf13e693..88581b1cb4f4 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args): if k in annotations: attrs[k] = annotations[k] - headers = ["tvm/runtime/registry.h"] + headers = ["tvm/ffi/function.h"] if "relu" in func_name: headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h") diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index f7df8d441743..ab416ef14176 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2087,7 +2087,7 @@ def extern( out: OutType, ) -> OutType: """Invoke an extern function during runtime. The extern function must be registered with the " - TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). + TVM runtime using `TVM_FFI_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). Parameters ---------- diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index a07193ea9852..f0d1bdeb0a76 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -17,6 +17,6 @@ """FFI APIs for tvm.runtime""" import tvm._ffi -# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix. -# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") +# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "runtime" prefix. +# e.g. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") tvm._ffi._init_api("runtime", __name__) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 395496d16be7..600206d42583 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -24,7 +24,7 @@ # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. # They will be overriden via _init_api to the ones registered -# via TVM_REGISTER_GLOBAL in the compiler mode. +# via TVM_FFI_REGISTER_GLOBAL in the compiler mode. def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" @@ -45,6 +45,6 @@ def LoadJSON(json_str): raise RuntimeError("Do not support object serialization in runtime only mode") -# Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix. -# e.g. TVM_REGISTER_GLOBAL("node.AsRepr") +# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "node" prefix. +# e.g. TVM_FFI_REGISTER_GLOBAL("node.AsRepr") tvm._ffi._init_api("node", __name__) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 8dcb67190a93..f0a317659d3a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -21,7 +21,7 @@ * \file tvm/arith/analyzer.cc */ #include -#include +#include #include #include @@ -269,7 +269,7 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") +TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { using ffi::Function; using ffi::TypedFunction; diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index d52ae7e6fde3..b8b5d6482428 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -22,7 +22,7 @@ * \brief Utility to deduce bound of expression */ #include -#include +#include #include #include @@ -402,7 +402,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, return DeduceBound(v, e, hmap, rmap); } -TVM_REGISTER_GLOBAL("arith.DeduceBound") +TVM_FFI_REGISTER_GLOBAL("arith.DeduceBound") .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, const Map relax_map) { return DeduceBound(v, cond, hint_map, relax_map); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 7409ecc6f37e..a440b52074e8 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -21,7 +21,7 @@ * \file tvm/arith/const_int_bound.cc */ #include -#include +#include #include #include @@ -51,7 +51,7 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); +TVM_FFI_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index b496e7fefca5..303360002e03 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -69,6 +69,6 @@ Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { return results; } -TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); +TVM_FFI_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 4d3164cbd382..0dcbc7623590 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -22,7 +22,7 @@ * \brief Utility to detect patterns in the expression. */ #include -#include +#include #include #include #include @@ -290,9 +290,9 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); +TVM_FFI_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); -TVM_REGISTER_GLOBAL("arith.DetectClipBound") +TVM_FFI_REGISTER_GLOBAL("arith.DetectClipBound") .set_body_typed([](const PrimExpr& e, const Array& vars) { return DetectClipBound(e, vars); }); diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 13243ddb5e44..5f9d78003001 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,7 +21,7 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ -#include +#include #include #include #include @@ -162,8 +162,8 @@ Map> DomainTouchedAccessMap(const PrimFunc& func) { return ret; } -TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); -TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); +TVM_FFI_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); +TVM_FFI_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 8c314992ab49..01e7a3096927 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include @@ -195,15 +195,16 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds") +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, Array upper) { return IntGroupBounds(coef, lower, equal, upper); }); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_from_range") + .set_body_typed(IntGroupBounds::FromRange); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.size() == 1 || args.size() == 2); auto bounds = args[0].cast(); @@ -243,7 +244,7 @@ IntConstraints::IntConstraints(Array variables, Map ranges, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); -TVM_REGISTER_GLOBAL("arith.IntConstraints") +TVM_FFI_REGISTER_GLOBAL("arith.IntConstraints") .set_body_typed([](Array variables, Map ranges, Array relations) { return IntConstraints(variables, ranges, relations); }); @@ -288,7 +289,7 @@ IntConstraintsTransform IntConstraintsTransform::operator+( TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); -TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") +TVM_FFI_REGISTER_GLOBAL("arith.IntConstraintsTransform") .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 94a2a369a664..d3b7b30628a1 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -57,7 +57,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -1192,42 +1192,42 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "[" << op->min_value << ", " << op->max_value << ']'; }); -TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint); +TVM_FFI_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint); -TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector); +TVM_FFI_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector); -TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval); +TVM_FFI_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing); +TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing); -TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); +TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); -TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); -TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); -TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); +TVM_FFI_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); +TVM_FFI_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); +TVM_FFI_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); } // namespace arith } // namespace tvm diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 5b1feef4e608..2aa0ca6b6425 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -47,7 +47,7 @@ IterMark::IterMark(PrimExpr source, PrimExpr extent) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { +TVM_FFI_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { return IterMark(source, extent); }); @@ -92,7 +92,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr ex data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterSplitExpr") +TVM_FFI_REGISTER_GLOBAL("arith.IterSplitExpr") .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { return IterSplitExpr(source, lower_factor, extent, scale); }); @@ -114,7 +114,7 @@ IterSumExpr::IterSumExpr(Array args, PrimExpr base) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterSumExpr") +TVM_FFI_REGISTER_GLOBAL("arith.IterSumExpr") .set_body_typed([](Array args, PrimExpr base) { return IterSumExpr(args, base); }); @@ -1513,7 +1513,7 @@ IterMapResult DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { @@ -1538,7 +1538,7 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iter return rewriter.RewriteToNormalizedIterSum(index); } -TVM_REGISTER_GLOBAL("arith.NormalizeToIterSum") +TVM_FFI_REGISTER_GLOBAL("arith.NormalizeToIterSum") .set_body_typed([](PrimExpr index, const Map& input_iters) { arith::Analyzer ana; return NormalizeToIterSum(index, input_iters, &ana); @@ -2133,7 +2133,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { return normalizer.Convert(expr); } -TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); +TVM_FFI_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, @@ -2162,7 +2162,7 @@ Array IterMapSimplify(const Array& indices, const Map& indices, const Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { @@ -2495,7 +2495,7 @@ Array> SubspaceDivide(const Array& bindings, return results; } -TVM_REGISTER_GLOBAL("arith.SubspaceDivide") +TVM_FFI_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, const Array& sub_iters, const PrimExpr& predicate, int check_level, bool simplify_trivial_iterators) { @@ -2634,7 +2634,7 @@ Map InverseAffineIterMap(const Array& iter_map, return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); } -TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_FFI_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); TVM_REGISTER_NODE_TYPE(IterMapResultNode); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 197e5ec8b868..fa4891d5a00b 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -22,7 +22,7 @@ * \brief Modular set analysis */ #include -#include +#include #include #include #include @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); +TVM_FFI_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); // internal entry for const int bound struct ModularSetAnalyzer::Entry { diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index 40c7ab3c54ac..a1a9768110ed 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -22,7 +22,7 @@ * \brief Utility to deduce bound of expression */ #include -#include +#include #include #include #include @@ -212,7 +212,8 @@ PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameter return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); } -TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); +TVM_FFI_REGISTER_GLOBAL("arith.NarrowPredicateExpression") + .set_body_typed(NarrowPredicateExpression); } // namespace arith } // namespace tvm diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 4f4d7e18578f..e514ad1b1ad7 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -272,7 +272,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); } -TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); +TVM_FFI_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); TVM_REGISTER_NODE_TYPE(PresburgerSetNode); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index fb6250a778ef..4d90c61ea3cb 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -24,8 +24,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -454,7 +454,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") +TVM_FFI_REGISTER_GLOBAL("arith.SolveLinearEquations") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveLinearEquations(args[0].cast()); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 0e5e6d485e74..62f314d1902f 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -24,8 +24,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -535,7 +535,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ return transform; } -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { IntConstraints problem; PartialSolvedInequalities ret_ineq; @@ -553,7 +553,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesToRange(args[0].cast()); @@ -568,7 +568,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") } }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesDeskewRange(args[0].cast()); diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 26f46bd032af..d38beab5b4ed 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -1431,14 +1431,14 @@ TVM_REGISTER_NODE_TYPE(MSCGraphNode); TVM_REGISTER_NODE_TYPE(WeightGraphNode); -TVM_REGISTER_GLOBAL("msc.core.MSCTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensor") .set_body_typed([](const String& name, const DataType& dtype, const String& layout, const Array& shape, const String& alias, const Array& prims) -> MSCTensor { return MSCTensor(name, dtype, layout, shape, alias, prims); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorToJson") .set_body_typed([](const MSCTensor& tensor) -> String { const auto& tensor_json = tensor->ToJson(); std::ostringstream os; @@ -1447,10 +1447,10 @@ TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") return os.str(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorFromJson") .set_body_typed([](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJoint") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& optype, const Map& attrs, const Array& scope, const Array& parents, @@ -1464,7 +1464,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint") weights); }); -TVM_REGISTER_GLOBAL("msc.core.MSCPrim") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCPrim") .set_body_typed([](Integer index, const String& name, const String& optype, const Map& attrs, const Array& parents) -> MSCPrim { Array b_parents; @@ -1474,7 +1474,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCPrim") return MSCPrim(index->value, name, optype, b_parents, attrs); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJoint") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, const Array parents, const Map& attrs, @@ -1490,108 +1490,109 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJoint") b_friends); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") .set_body_typed([](const WeightJoint& node, const String& key, const String& value) { node->attrs.Set(key, value); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraph") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraph") .set_body_typed([](const String& name, const Array& nodes, const Array& input_names, const Array& output_names, const Array& prims) -> MSCGraph { return MSCGraph(name, nodes, input_names, output_names, prims); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraph") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraph") .set_body_typed([](const MSCGraph& graph, const Map>& main_wtypes, const Map& relation_wtypes) -> WeightGraph { return WeightGraph(graph, main_wtypes, relation_wtypes); }); // MSC Graph APIS -TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasNode(name)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindNode(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCPrim { return graph->FindPrim(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasTensor(name)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCTensor { return graph->FindTensor(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") .set_body_typed([](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { tensor->alias = alias; graph->tensor_alias.Set(alias, tensor->name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindProducer(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindConsumers") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindConsumers") .set_body_typed([](const MSCGraph& graph, const String& name) -> Array { return graph->FindConsumers(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphInputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphInputAt") .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { return graph->InputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphOutputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphOutputAt") .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphGetInputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetInputs") .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetInputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphGetOutputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetOutputs") .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphToJson").set_body_typed([](const MSCGraph& graph) -> String { - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - dmlc::JSONWriter writer(&os); - graph_json.Save(&writer); - return os.str(); -}); +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToJson") + .set_body_typed([](const MSCGraph& graph) -> String { + const auto& graph_json = graph->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + graph_json.Save(&writer); + return os.str(); + }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") .set_body_typed([](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") .set_body_typed([](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); // Weight Graph APIS -TVM_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") .set_body_typed([](const WeightGraph& graph, const String& name) -> Bool { return Bool(graph->HasNode(name)); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") .set_body_typed([](const WeightGraph& graph, const String& name) -> WeightJoint { return graph->FindNode(name); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphToJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToJson") .set_body_typed([](const WeightGraph& graph) -> String { const auto& graph_json = graph->ToJson(); std::ostringstream os; @@ -1600,69 +1601,69 @@ TVM_REGISTER_GLOBAL("msc.core.WeightGraphToJson") return os.str(); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") .set_body_typed([](const String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") .set_body_typed([](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointInputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointInputAt") .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") .set_body_typed([](const MSCJoint& node, const String& wtype) -> MSCTensor { return node->WeightAt(wtype); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") .set_body_typed([](const MSCJoint& node) -> Array { return node->GetInputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetOutputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetOutputs") .set_body_typed([](const MSCJoint& node) -> Array { return node->GetOutputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetWeights") .set_body_typed([](const MSCJoint& node) -> Map { return node->weights; }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") .set_body_typed([](const MSCJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") .set_body_typed([](const MSCJoint& node) -> Map { return node->attrs; }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") .set_body_typed([](const WeightJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") .set_body_typed([](const WeightJoint& node) -> Map { return node->attrs; }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName") .set_body_typed([](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorDimAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDimAt") .set_body_typed([](const MSCTensor& tensor, const String& axis) -> Integer { return tensor->DimAt(axis); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorGetSize") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorGetSize") .set_body_typed([](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias") .set_body_typed([](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }); -TVM_REGISTER_GLOBAL("msc.core.PruneWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.PruneWeights") .set_body_typed([](const MSCGraph& graph, const Map& pruned_tensors) -> MSCGraph { return PruneWeights(graph, pruned_tensors); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 030df82dc6ee..2550f5652fc7 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -834,7 +834,7 @@ void WeightsExtractor::VisitExpr_(const CallNode* op) { } } -TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax") +TVM_FFI_REGISTER_GLOBAL("msc.core.BuildFromRelax") .set_body_typed([](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { auto builder = GraphBuilder(module, entry_name, options); @@ -844,7 +844,7 @@ TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax") return builder.Build(func); }); -TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.GetRelaxWeights") .set_body_typed([](const IRModule& module, const String& entry_name) -> Map { const auto& func = Downcast(module->Lookup(entry_name)); diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index d34972639a7b..fc6000a20f3d 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -305,20 +305,20 @@ const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Ge bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } -TVM_REGISTER_GLOBAL("msc.core.RegisterPlugin") +TVM_FFI_REGISTER_GLOBAL("msc.core.RegisterPlugin") .set_body_typed([](const String& name, const String& json_str) { PluginRegistry::Global()->Register(name, json_str); }); -TVM_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { +TVM_FFI_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { return ListPluginNames(); }); -TVM_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { +TVM_FFI_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { return GetPlugin(name); }); -TVM_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { +TVM_FFI_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { return Bool(IsPlugin(name)); }); diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 523db32b3a8d..0225ff319097 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -154,7 +154,7 @@ Pass BindNamedParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindNamedParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); } // namespace transform diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index ca91b3424e0d..b554e08ab820 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -132,7 +132,7 @@ Pass BindShape(const String& entry_name) { return CreateModulePass(pass_func, 0, "BindShape", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindShape").set_body_typed(BindShape); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindShape").set_body_typed(BindShape); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 1bcbe076fc59..1eabf3306f36 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -231,7 +231,7 @@ Pass FuseTuple(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "FuseTuple", {}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index e892d3b36a42..a91eb590af26 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -184,7 +184,7 @@ Pass InlineParams(const String& entry_name) { return CreateModulePass(pass_func, 0, "InlineParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 6f3a29346bfd..4755ebf38960 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -101,7 +101,7 @@ Pass SetBYOCAttrs(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 80416bafd0f2..dd87e60e7b80 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -1359,7 +1359,7 @@ Pass SetExprLayout(bool allow_missing, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetExprLayout", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 4c36e2ba2754..4d0cc0314e18 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -324,7 +324,7 @@ Pass SetRelaxExprName(const String& entry_name, const String& target, return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 3c53e32170c3..d03f3ba82b28 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -523,27 +523,27 @@ const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(GetStructInfo(expr))->dtype; } -TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); -TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); -TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") .set_body_typed([](const String& key, const String& value) -> Span { return SpanUtils::CreateWithAttr(key, value); }); -TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanSetAttr") .set_body_typed([](const Span& span, const String& key, const String& value) -> Span { return SpanUtils::SetAttr(span, key, value); }); -TVM_REGISTER_GLOBAL("msc.core.CompareVersion") +TVM_FFI_REGISTER_GLOBAL("msc.core.CompareVersion") .set_body_typed([](const Array& given_version, const Array& target_version) -> Integer { return Integer(CommonUtils::CompareVersion(given_version, target_version)); }); -TVM_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { +TVM_FFI_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { return CommonUtils::ToAttrKey(key); }); diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 9506d4eac818..4bceb76d4699 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -150,7 +150,7 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 0271b63112c9..8b85f2e88f04 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -574,7 +574,7 @@ const Map TensorRTCodeGen::GetStepCtx() { return step_ctx; } -TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); @@ -582,7 +582,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") return codegen.GetSources(print_config); }); -TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { #ifdef TENSORRT_ROOT_DIR return TENSORRT_ROOT_DIR; #else @@ -618,7 +618,7 @@ Array MSCTensorRTCompiler(Array functions, return compiled_functions; } -TVM_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler); +TVM_FFI_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index fc0e538252ca..67f453268e2a 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -913,7 +913,7 @@ Pass TransformTensorRT(const String& config) { return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } -TVM_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); +TVM_FFI_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 547c1c22ba75..228efa4381ee 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -151,7 +151,7 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 1d6d74d7e43a..53d1bc0562fc 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -210,7 +210,7 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc index d16bb5d92440..02904c3bd9c8 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -883,7 +883,7 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 215341af0133..59b99f22c7ce 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -492,7 +492,7 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index e1d3c9960f6d..610fbc4c3282 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -213,12 +213,12 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { stack_.func_end("infer_output"); // register funcs - stack_.func_call("TVM_REGISTER_GLOBAL") + stack_.func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name)) .method_call("set_body_typed") .call_arg("InferStructInfo" + plugin->name) .line() - .func_call("TVM_REGISTER_GLOBAL") + .func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name)) .method_call("set_body_typed") .call_arg("InferLayout" + plugin->name) @@ -260,7 +260,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { CodeGenCompute(plugin, "cpu"); stack_.cond_end().func_end(); // register the compute - stack_.func_call("TVM_REGISTER_GLOBAL") + stack_.func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr(plugin->name)) .method_call("set_body") .call_arg(func_name) @@ -393,7 +393,7 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 9de36b0a28af..3a54085c2290 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -43,7 +43,7 @@ Map> CollectCallMap(const IRModule& mod) { return call_map; } -TVM_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); +TVM_FFI_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); } // namespace ir } // namespace tvm diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 9e43e33a6c4a..877530f4c378 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -21,9 +21,9 @@ * \file src/ir/apply_pass_to_function.cc * \brief Utility transformation that applies an inner pass to a subset of an IRModule */ +#include #include #include -#include #include #include @@ -130,7 +130,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, return CreateModulePass(pass_func, 0, pass_name, {}); } -TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); +TVM_FFI_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); } // namespace transform } // namespace tvm diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index fd87c2bc8e0c..52a2cceeaf79 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -20,8 +20,8 @@ /*! * \file attrs.cc */ +#include #include -#include #include "attr_functor.h" @@ -73,11 +73,11 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { +TVM_FFI_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { return attrs->dict; }); -TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { +TVM_FFI_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index ec11f2c04f6c..70197074317d 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -33,7 +33,7 @@ namespace tvm { /* Diagnostic */ TVM_REGISTER_NODE_TYPE(DiagnosticNode); -TVM_REGISTER_GLOBAL("diagnostics.Diagnostic") +TVM_FFI_REGISTER_GLOBAL("diagnostics.Diagnostic") .set_body_typed([](int level, Span span, String message) { return Diagnostic(static_cast(level), span, message); }); @@ -106,7 +106,7 @@ TVM_DLL DiagnosticRenderer::DiagnosticRenderer( data_ = std::move(n); } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") .set_body_typed([](ffi::TypedFunction renderer) { return DiagnosticRenderer(renderer); }); @@ -134,7 +134,7 @@ void DiagnosticContext::Render() { } } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") .set_body_typed([](DiagnosticRenderer renderer, DiagnosticContext ctx) { renderer.Render(ctx); }); @@ -147,7 +147,7 @@ DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRen data_ = std::move(n); } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContext") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContext") .set_body_typed([](const IRModule& module, const DiagnosticRenderer& renderer) { return DiagnosticContext(module, renderer); }); @@ -157,12 +157,12 @@ void DiagnosticContext::Emit(const Diagnostic& diagnostic) { (*this)->diagnostics.push_back(diagnostic); } -TVM_REGISTER_GLOBAL("diagnostics.Emit") +TVM_FFI_REGISTER_GLOBAL("diagnostics.Emit") .set_body_typed([](DiagnosticContext ctx, const Diagnostic& diagnostic) { return ctx.Emit(diagnostic); }); -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") .set_body_typed([](DiagnosticContext context) { return context.Render(); }); /*! \brief Emit a diagnostic. */ @@ -195,7 +195,7 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } -TVM_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { +TVM_FFI_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { return DiagnosticContext::Default(module); }); @@ -311,11 +311,13 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cerr); }); +TVM_FFI_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { + return TerminalRenderer(std::cerr); +}); -TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); +TVM_FFI_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); -TVM_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { tvm::ffi::Function::RemoveGlobal(OVERRIDE_RENDERER); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 9713f88f7ddd..ce40df21eb9a 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -20,8 +20,8 @@ /*! * \file env_func.cc */ +#include #include -#include #include namespace tvm { @@ -47,15 +47,15 @@ ObjectPtr CreateEnvNode(const std::string& name) { EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { EnvFunc env = args[0].cast(); ICHECK_GE(args.size(), 1); env->func.CallPacked(args.Slice(1), rv); }); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { return n->func; }); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index de665dcd22b3..387572f6427b 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -22,9 +22,9 @@ * \brief The expression AST nodes for the common IR infra. */ #include +#include #include #include -#include #include #include @@ -64,7 +64,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { +TVM_FFI_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { return IntImm(dtype, value, span); }); @@ -115,7 +115,7 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { +TVM_FFI_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { return FloatImm(dtype, value, span); }); @@ -128,9 +128,9 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { return Range(make_object(min, extent, span)); } -TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); +TVM_FFI_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range") +TVM_FFI_REGISTER_GLOBAL("ir.Range") .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { if (end.defined()) { return Range(begin, end.value(), span); @@ -151,11 +151,11 @@ GlobalVar::GlobalVar(String name_hint, Type type, Span span) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { return GlobalVar(name, type); }); -TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { +TVM_FFI_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; return ss.str(); diff --git a/src/ir/function.cc b/src/ir/function.cc index 8f543b03260c..66d66e3c8133 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -21,19 +21,21 @@ * \file src/ir/function.cc * \brief The function data structure. */ +#include #include #include #include -#include #include namespace tvm { -TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; }); +TVM_FFI_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { + return func->attrs; +}); -TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { @@ -45,7 +47,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") .set_body_typed([](ffi::RValueRef func_ref, Map attr_map) -> BaseFunc { BaseFunc func = *std::move(func_ref); @@ -61,7 +63,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") TVM_FFI_UNREACHABLE(); }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") .set_body_typed([](ffi::RValueRef func_ref, String key) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 6abac574e1b7..3df9ae00fb53 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -25,7 +25,7 @@ #include namespace tvm { TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); -TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); return n; }); @@ -39,7 +39,8 @@ VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { } TVM_REGISTER_NODE_TYPE(VDeviceNode); -TVM_REGISTER_GLOBAL("ir.VDevice").set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) { - return VDevice(tgt, dev_id, mem_scope); -}); +TVM_FFI_REGISTER_GLOBAL("ir.VDevice") + .set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) { + return VDevice(tgt, dev_id, mem_scope); + }); } // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 3d3b8919916f..1b47c1a89639 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -23,7 +23,7 @@ */ #include "tvm/ir/global_var_supply.h" -#include +#include #include @@ -92,24 +92,23 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { return GlobalVarSupply(std::move(mod)); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array& mods) { - return GlobalVarSupply(mods); -}); +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules") + .set_body_typed([](const Array& mods) { return GlobalVarSupply(mods); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") .set_body_method(&GlobalVarSupplyNode::FreshGlobal); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") .set_body_method(&GlobalVarSupplyNode::UniqueGlobalFor); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") .set_body_method(&GlobalVarSupplyNode::ReserveGlobalVar); } // namespace tvm diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index ad66f2944891..a273245c1b64 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -22,10 +22,10 @@ * \brief Infrastructure for instrumentation. */ #include +#include #include #include #include -#include #include @@ -175,7 +175,7 @@ void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); -TVM_REGISTER_GLOBAL("instrument.PassInstrument") +TVM_FFI_REGISTER_GLOBAL("instrument.PassInstrument") .set_body_typed( [](String name, ffi::TypedFunction enter_pass_ctx, ffi::TypedFunction exit_pass_ctx, @@ -308,9 +308,9 @@ String RenderPassProfiles() { return os.str(); } -TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); +TVM_FFI_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); -TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { PassProfile::EnterPass(pass_info->name); return true; diff --git a/src/ir/module.cc b/src/ir/module.cc index 0223aef4d15b..3166ffba9787 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,12 +21,12 @@ * \brief The global module in TVM. */ #include +#include #include #include #include #include #include -#include #include #include @@ -242,7 +242,7 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, TVM_REGISTER_NODE_TYPE(IRModuleNode); -TVM_REGISTER_GLOBAL("ir.IRModule") +TVM_FFI_REGISTER_GLOBAL("ir.IRModule") .set_body_typed([](tvm::Map funcs, tvm::ObjectRef attrs, Map> global_infos) { auto dict_attrs = [&attrs]() { @@ -260,20 +260,20 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return IRModule(funcs, {}, dict_attrs, global_infos); }); -TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule { IRModule clone = mod; clone.CopyOnWrite(); return clone; }); -TVM_REGISTER_GLOBAL("ir.Module_Add") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { ICHECK(val->IsInstance()); mod->Add(var, Downcast(val), update); return mod; }); -TVM_REGISTER_GLOBAL("ir.Module_Remove") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Remove") .set_body_typed([](IRModule mod, Variant var) -> IRModule { GlobalVar gvar = [&]() { if (auto opt = var.as()) { @@ -289,7 +289,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Remove") return mod; }); -TVM_REGISTER_GLOBAL("ir.Module_Contains") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Contains") .set_body_typed([](IRModule mod, Variant var) -> bool { if (auto opt = var.as()) { return mod->functions.count(opt.value()); @@ -301,55 +301,57 @@ TVM_REGISTER_GLOBAL("ir.Module_Contains") } }); -TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar").set_body_method(&IRModuleNode::GetGlobalVar); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVar").set_body_method(&IRModuleNode::GetGlobalVar); -TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars").set_body_method(&IRModuleNode::GetGlobalVars); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVars").set_body_method(&IRModuleNode::GetGlobalVars); -TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar").set_body_method(&IRModuleNode::ContainGlobalVar); +TVM_FFI_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") + .set_body_method(&IRModuleNode::ContainGlobalVar); -TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); +TVM_FFI_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); -TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") +TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateFunction") .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); -TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") +TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") .set_body_typed([](IRModule mod, String name, Array global_info) { mod->UpdateGlobalInfo(name, global_info); }); -TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { return mod->GetAttrs(); }); -TVM_REGISTER_GLOBAL("ir.Module_WithAttr") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttr") .set_body_typed([](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { return WithAttr(*std::move(mod), key, value); }); -TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithoutAttr") .set_body_typed([](ffi::RValueRef mod, String key) -> IRModule { return WithoutAttr(*std::move(mod), key); }); -TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttrs") .set_body_typed([](ffi::RValueRef mod, Map attr_map) -> IRModule { return WithAttrs(*std::move(mod), attr_map); }); -TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { - return mod->GetAttr(key); -}); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttr") + .set_body_typed([](IRModule mod, String key) -> ObjectRef { + return mod->GetAttr(key); + }); } // namespace tvm diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 087fc82a50f7..e73b0e63e3d0 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -23,7 +23,7 @@ */ #include "tvm/ir/name_supply.h" -#include +#include #include @@ -92,14 +92,15 @@ std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) TVM_REGISTER_NODE_TYPE(NameSupplyNode); -TVM_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { return NameSupply(prefix); }); -TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName").set_body_method(&NameSupplyNode::FreshName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_FreshName").set_body_method(&NameSupplyNode::FreshName); -TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName").set_body_method(&NameSupplyNode::ReserveName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ReserveName").set_body_method(&NameSupplyNode::ReserveName); -TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName").set_body_method(&NameSupplyNode::ContainsName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ContainsName") + .set_body_method(&NameSupplyNode::ContainsName); } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index 70f7528e5e76..4917f8336b1d 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -75,13 +75,13 @@ void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { } // Frontend APIs -TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { return OpRegistry::Global()->ListAllNames(); }); -TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); +TVM_FFI_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); -TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { +TVM_FFI_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { auto op_map = Op::GetAttrMap(attr_name); ffi::Any rv; if (op_map.count(op)) { @@ -90,50 +90,50 @@ TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) - return rv; }); -TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool { +TVM_FFI_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool { return Op::HasAttrMap(attr_name); }); -TVM_REGISTER_GLOBAL("ir.OpSetAttr") +TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttr") .set_body_typed([](Op op, String attr_name, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }); -TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { +TVM_FFI_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }); -TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); op.describe(descr); }); -TVM_REGISTER_GLOBAL("ir.OpAddArgument") +TVM_FFI_REGISTER_GLOBAL("ir.OpAddArgument") .set_body_typed([](Op op, String name, String type, String description) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.add_argument(name, type, description); }); -TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_support_level(level); }); -TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_num_inputs(n); }); -TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attrs_type_key(key); }); -TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpAttr") .set_body_typed([](String op_name, String attr_key, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable resgiteration and override of certain properties @@ -146,7 +146,7 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") } }); -TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") .set_body_typed([](String name, ffi::Function f, String target, int plevel) { tvm::OpRegEntry::RegisterOrGet(name).set_attr(target + ".FLowerIntrinsic", f, plevel); diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 44a6a22db7a0..0dca97302470 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -62,7 +62,7 @@ IRModule ReplaceGlobalVars(IRModule mod, Map replacements) return mod; } -TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); +TVM_FFI_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); IRModule ModuleReplaceGlobalVars( IRModule mod, Map, Variant> replacements) { @@ -93,7 +93,7 @@ IRModule ModuleReplaceGlobalVars( return ReplaceGlobalVars(mod, gvar_replacements); } -TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); +TVM_FFI_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); } // namespace transform } // namespace tvm diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 8e25b25a4ca4..482e1dfa1018 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -20,9 +20,9 @@ * \file source_map.cc * \brief The implementation of the source map data structure. */ +#include #include #include -#include #include @@ -50,7 +50,7 @@ ObjectPtr GetSourceNameNodeByStr(const std::string& name) { SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } -TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); +TVM_FFI_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -129,12 +129,12 @@ SequentialSpan::SequentialSpan(std::initializer_list init) { TVM_REGISTER_NODE_TYPE(SequentialSpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, - int column, int end_column) { +TVM_FFI_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, + int column, int end_column) { return Span(source_name, line, end_line, column, end_column); }); -TVM_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { +TVM_FFI_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { return SequentialSpan(spans); }); @@ -218,11 +218,12 @@ SourceMap::SourceMap(Map source_map) { void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } -TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { - auto src_name = SourceName::Get(name); - Source source(src_name, content); - map.Add(source); - return src_name; -}); +TVM_FFI_REGISTER_GLOBAL("SourceMapAdd") + .set_body_typed([](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 44dafcbf1d9e..db4e47ca0d1a 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -22,6 +22,7 @@ * \brief Infrastructure for transformation passes. */ #include +#include #include #include #include @@ -29,14 +30,12 @@ #include #include #include -#include #include #include #include #include -#include "../runtime/object_internal.h" #include "../runtime/regex.h" namespace tvm { @@ -532,12 +531,12 @@ Pass CreateModulePass(std::function pass_func, TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_GLOBAL("transform.PassInfo") +TVM_FFI_REGISTER_GLOBAL("transform.PassInfo") .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { return PassInfo(opt_level, name, required, traceable); }); -TVM_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { +TVM_FFI_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { Pass pass = args[0].cast(); *ret = pass->Info(); }); @@ -562,7 +561,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_GLOBAL("transform.MakeModulePass") +TVM_FFI_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( [](ffi::TypedFunction, PassContext)> pass_func, PassInfo pass_info) { @@ -572,7 +571,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") return ModulePass(wrapped_pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("transform.RunPass") +TVM_FFI_REGISTER_GLOBAL("transform.RunPass") .set_body_typed([](Pass pass, ffi::RValueRef mod) { return pass(*std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -585,7 +584,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential") +TVM_FFI_REGISTER_GLOBAL("transform.Sequential") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto passes = args[0].cast>(); int opt_level = args[1].cast(); @@ -612,7 +611,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext") +TVM_FFI_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, Optional> config, Array trace_stack, @@ -658,24 +657,27 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("transform.GetTraceStack").set_body_method(&PassContextNode::GetTraceStack); -TVM_REGISTER_GLOBAL("transform.PushTrace").set_body_method(&PassContextNode::PushTrace); -TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); -TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") +TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStack").set_body_method(&PassContextNode::GetTraceStack); +TVM_FFI_REGISTER_GLOBAL("transform.PushTrace").set_body_method(&PassContextNode::PushTrace); +TVM_FFI_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStackSize") .set_body_method(&PassContextNode::GetTraceStackSize); -TVM_REGISTER_GLOBAL("transform.GetCurrentTrace").set_body_method(&PassContextNode::GetCurrentTrace); -TVM_REGISTER_GLOBAL("transform.SetNumEvals").set_body_method(&PassContextNode::SetNumEvals); -TVM_REGISTER_GLOBAL("transform.IncNumEvals").set_body_method(&PassContextNode::IncNumEvals); -TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") +TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_FFI_REGISTER_GLOBAL("transform.SetNumEvals").set_body_method(&PassContextNode::SetNumEvals); +TVM_FFI_REGISTER_GLOBAL("transform.IncNumEvals").set_body_method(&PassContextNode::IncNumEvals); +TVM_FFI_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") .set_body_method(&PassContextNode::GetTuningAPIDatabase); -TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); +TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); +TVM_FFI_REGISTER_GLOBAL("transform.EnterPassContext") + .set_body_typed(PassContext::Internal::EnterScope); -TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); +TVM_FFI_REGISTER_GLOBAL("transform.ExitPassContext") + .set_body_typed(PassContext::Internal::ExitScope); -TVM_REGISTER_GLOBAL("transform.OverrideInstruments") +TVM_FFI_REGISTER_GLOBAL("transform.OverrideInstruments") .set_body_typed([](PassContext pass_ctx, Array instruments) { pass_ctx.InstrumentExitPassContext(); pass_ctx->instruments = instruments; @@ -690,9 +692,9 @@ Pass PrintIR(String header, bool show_meta_data) { return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } -TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); +TVM_FFI_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); -TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); +TVM_FFI_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 3c648418c6a9..8bc48a11141f 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -21,8 +21,8 @@ * \file src/ir/type.cc * \brief Common type system AST nodes throughout the IR. */ +#include #include -#include namespace tvm { PrimType::PrimType(runtime::DataType dtype, Span span) { @@ -34,7 +34,7 @@ PrimType::PrimType(runtime::DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { +TVM_FFI_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); @@ -47,7 +47,7 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_REGISTER_GLOBAL("ir.PointerType") +TVM_FFI_REGISTER_GLOBAL("ir.PointerType") .set_body_typed([](Type element_type, String storage_scope = "") { return PointerType(element_type, storage_scope); }); @@ -62,9 +62,10 @@ FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { TVM_REGISTER_NODE_TYPE(FuncTypeNode); -TVM_REGISTER_GLOBAL("ir.FuncType").set_body_typed([](tvm::Array arg_types, Type ret_type) { - return FuncType(arg_types, ret_type); -}); +TVM_FFI_REGISTER_GLOBAL("ir.FuncType") + .set_body_typed([](tvm::Array arg_types, Type ret_type) { + return FuncType(arg_types, ret_type); + }); TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); @@ -77,7 +78,7 @@ TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { +TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index bc0fef1ff4a7..58c4a7b33c4f 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -162,11 +162,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); TVM_REGISTER_NODE_TYPE(TensorInfoNode); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc") + .set_body_typed(ArgInfo::FromEntryFunc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TensorInfo") .set_body_typed([](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { return TensorInfo(dtype, shape); }); diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 9d725e91e247..85e189e73228 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -52,21 +52,21 @@ TVM_REGISTER_NODE_TYPE(BuilderResultNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_NODE_TYPE(PyBuilderNode); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderInput") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderInput") .set_body_typed([](IRModule mod, Target target, Optional> params) -> BuilderInput { return BuilderInput(mod, target, params); }); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderResult") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderResult") .set_body_typed([](Optional artifact_path, Optional error_msg) -> BuilderResult { return BuilderResult(artifact_path, error_msg); }); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 1d28eb19d7cb..5c1c7a568580 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -71,10 +71,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_NODE_TYPE(PyCostModelNode); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate").set_body_method(&CostModelNode::Update); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelUpdate").set_body_method(&CostModelNode::Update); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPredict") .set_body_typed([](CostModel model, // const TuneContext& context, // Array candidates, // @@ -82,7 +82,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") std::vector result = model->Predict(context, candidates); std::copy(result.begin(), result.end(), static_cast(p_addr)); }); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel") + .set_body_typed(CostModel::PyCostModel); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 9c23d9e845e6..034294eedcd3 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -282,43 +282,46 @@ TVM_REGISTER_NODE_TYPE(WorkloadNode); TVM_REGISTER_NODE_TYPE(TuningRecordNode); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); TVM_REGISTER_NODE_TYPE(PyDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { +TVM_FFI_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { return Workload(mod); }); -TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON").set_body_method(&WorkloadNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON").set_body_method(&WorkloadNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecord") .set_body_typed([](tir::Trace trace, Workload workload, Optional> run_secs, Optional target, Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") .set_body_method(&TuningRecordNode::AsMeasureCandidate); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON").set_body_method(&TuningRecordNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON") + .set_body_typed(TuningRecord::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") .set_body_method(&Database::EnterWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") .set_body_method(&Database::ExitWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") .set_body_method(&DatabaseNode::HasWorkload); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") .set_body_method(&DatabaseNode::GetAllTuningRecords); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") .set_body_method(&DatabaseNode::QueryTuningRecord); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") .set_body_method(&DatabaseNode::QuerySchedule); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") .set_body_method(&DatabaseNode::QueryIRModule); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned").set_body_method(&DatabaseNode::DumpPruned); -TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned") + .set_body_method(&DatabaseNode::DumpPruned); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index a59b334b221f..2a6b93f8cb3b 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -214,7 +214,8 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, } TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase") + .set_body_typed(Database::JSONDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 3d418206b031..cbc811752cad 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -97,7 +97,7 @@ Database Database::MemoryDatabase(String mod_eq_name) { } TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") .set_body_typed(Database::MemoryDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index df151764a6ed..87f5c03a71eb 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -79,7 +79,7 @@ Database Database::OrderedUnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") .set_body_typed(Database::OrderedUnionDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index 0dbac9616c49..c66ec5f4f0c1 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -99,7 +99,7 @@ Database Database::ScheduleFnDatabase(ffi::TypedFunction sc } TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") .set_body_typed(Database::ScheduleFnDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 876affbdb2b9..2bc82b459cad 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -82,7 +82,8 @@ Database Database::UnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase").set_body_typed(Database::UnionDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase") + .set_body_typed(Database::UnionDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index ec04361f51ec..fb26e6eb693c 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -39,7 +39,7 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, } TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); -TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ExtractedTask") .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index 093558d2284e..9a3cecf4ce26 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -48,9 +48,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") .set_body_method(&FeatureExtractorNode::ExtractFrom); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") .set_body_typed(FeatureExtractor::PyFeatureExtractor); } // namespace meta_schedule diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 5d91bc34f9ae..2fc8878546d8 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -1442,7 +1442,7 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, } TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") .set_body_typed(FeatureExtractor::PerStoreFeature); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 68a4b93ea96f..becd9d2110df 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -65,7 +65,7 @@ MeasureCallback MeasureCallback::AddToDatabase() { } TVM_REGISTER_NODE_TYPE(AddToDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") .set_body_typed(MeasureCallback::AddToDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 8f94e298463a..0ee49f2ab4f9 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -59,11 +59,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") .set_body_method(&MeasureCallbackNode::Apply); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") .set_body_typed(MeasureCallback::PyMeasureCallback); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") .set_body_typed(MeasureCallback::Default); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index 9242e79912df..da74e85cac07 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -46,7 +46,7 @@ MeasureCallback MeasureCallback::RemoveBuildArtifact() { } TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") .set_body_typed(MeasureCallback::RemoveBuildArtifact); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 63c32b189eee..1969d7fc83a9 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -63,7 +63,7 @@ MeasureCallback MeasureCallback::UpdateCostModel() { } TVM_REGISTER_NODE_TYPE(UpdateCostModelNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") .set_body_typed(MeasureCallback::UpdateCostModel); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index f74d1640e475..8f8c077aa815 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -127,7 +127,7 @@ Mutator Mutator::MutateComputeLocation() { } TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") .set_body_typed(Mutator::MutateComputeLocation); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 6ddb6ea96bc9..a6a34e47a9d9 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -312,7 +312,8 @@ Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { } TVM_REGISTER_NODE_TYPE(MutateParallelNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel") + .set_body_typed(Mutator::MutateParallel); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 3666e84bd59f..269b05240443 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -165,7 +165,7 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") .set_body_typed(Mutator::MutateThreadBinding); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 35dc23ceba60..e8a728d05033 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -269,7 +269,8 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize") + .set_body_typed(Mutator::MutateTileSize); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 782da45a355c..28fcf3668f27 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -138,7 +138,7 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateUnrollNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index e1831d213d1e..e415b3909f10 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -88,20 +88,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(MutatorNode); TVM_REGISTER_NODE_TYPE(PyMutatorNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") .set_body_method(&MutatorNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorApply") .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); }); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") .set_body_typed(Mutator::DefaultCUDATensorCore); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon").set_body_typed(Mutator::DefaultHexagon); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon") + .set_body_typed(Mutator::DefaultHexagon); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index c32c7515facd..01a75a5bfb36 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -184,7 +184,7 @@ Postproc Postproc::DisallowAsyncStridedMemCopy() { } TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy") .set_body_typed(Postproc::DisallowAsyncStridedMemCopy); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 8362da552ea5..fd099ac5dd38 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -83,7 +83,7 @@ Postproc Postproc::DisallowDynamicLoop() { } TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") .set_body_typed(Postproc::DisallowDynamicLoop); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index b5d62634c23f..e29f9dd54c5a 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -112,16 +112,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(PostprocNode); TVM_REGISTER_NODE_TYPE(PyPostprocNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") .set_body_method(&PostprocNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") .set_body_typed(Postproc::DefaultCUDATensorCore); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") .set_body_typed(Postproc::DefaultHexagon); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index aa206764cce1..d23e07795cad 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -227,7 +227,7 @@ Postproc Postproc::RewriteCooperativeFetch() { } TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") .set_body_typed(Postproc::RewriteCooperativeFetch); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index cdfc6c56e549..84dc33ec98c8 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -273,7 +273,8 @@ Postproc Postproc::RewriteLayout() { } TVM_REGISTER_NODE_TYPE(RewriteLayoutNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout") + .set_body_typed(Postproc::RewriteLayout); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index c43c8aac12f8..3f665cd8d82a 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -464,7 +464,7 @@ Postproc Postproc::RewriteParallelVectorizeUnroll() { } TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 295044596ff9..3ffe0f9234d2 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -172,7 +172,7 @@ Postproc Postproc::RewriteReductionBlock() { } TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") .set_body_typed(Postproc::RewriteReductionBlock); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 4f8e0fb213f8..0f98484dd44e 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -107,7 +107,7 @@ Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { } TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") .set_body_typed(Postproc::RewriteTensorize); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 27ce34a8cb27..a2c9d1364ab6 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -146,7 +146,7 @@ Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { } TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") .set_body_typed(Postproc::RewriteUnboundBlock); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index f87e1ed9d50f..8ffc424e4451 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -215,7 +215,8 @@ Postproc Postproc::VerifyGPUCode() { } TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode") + .set_body_typed(Postproc::VerifyGPUCode); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index 4de975089653..7da2f8546b9e 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -69,7 +69,7 @@ Postproc Postproc::VerifyVTCMLimit() { } TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit") .set_body_typed(Postproc::VerifyVTCMLimit); } // namespace meta_schedule diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index b1386710b273..2a034a7be297 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -121,17 +121,17 @@ Optional Profiler::Current() { } TVM_REGISTER_NODE_TYPE(ProfilerNode); -TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { +TVM_FFI_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { return Profiler(); }); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope") .set_body_method(&Profiler::EnterWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope") .set_body_method(&Profiler::ExitWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method(&ProfilerNode::Get); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method(&ProfilerNode::Table); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method(&ProfilerNode::Get); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method(&ProfilerNode::Table); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 0d5edf299e43..38d4225f0fbd 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -56,24 +56,25 @@ TVM_REGISTER_NODE_TYPE(RunnerResultNode); TVM_REGISTER_NODE_TYPE(RunnerFutureNode); TVM_REGISTER_OBJECT_TYPE(RunnerNode); TVM_REGISTER_NODE_TYPE(PyRunnerNode); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerInput") .set_body_typed([](String artifact_path, String device_type, Array args_info) -> RunnerInput { return RunnerInput(artifact_path, device_type, args_info); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerResult") .set_body_typed([](Optional> run_secs, Optional error_msg) -> RunnerResult { return RunnerResult(run_secs, error_msg); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFuture") .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { return RunnerFuture(f_done, f_result); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone").set_body_method(&RunnerFutureNode::Done); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult").set_body_method(&RunnerFutureNode::Result); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone").set_body_method(&RunnerFutureNode::Done); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") + .set_body_method(&RunnerFutureNode::Result); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index 16e53b56923a..4e09fa729b3c 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -59,7 +59,7 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -71,14 +71,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV block) -> Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -90,7 +90,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV block) -> Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 46ef6f366310..c80141f5288d 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -63,7 +63,7 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -88,7 +88,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { GetWinogradProducerAndInlineConst(sch, inverse); ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); @@ -101,7 +101,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; @@ -132,7 +132,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { GetWinogradProducerAndInlineConst(sch, inverse); // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 2fc1352677cb..48149ed871e4 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -120,7 +120,7 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: } TVM_REGISTER_NODE_TYPE(AddRFactorNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") .set_body_typed(ScheduleRule::AddRFactor); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 9fdcaa4ef535..92de19163af5 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -87,7 +87,7 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { } TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") .set_body_typed(ScheduleRule::ApplyCustomRule); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index fa47d1edb860..892a79ea926d 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -82,7 +82,8 @@ ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_ } TVM_REGISTER_NODE_TYPE(AutoBindNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind").set_body_typed(ScheduleRule::AutoBind); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind") + .set_body_typed(ScheduleRule::AutoBind); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index d9e033eff810..948632e580e6 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -191,7 +191,7 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // } TVM_REGISTER_NODE_TYPE(AutoInlineNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); /*! \brief Inline blocks that produce a constant scalar. */ @@ -232,7 +232,7 @@ ScheduleRule ScheduleRule::InlineConstantScalars() { } TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") .set_body_typed(ScheduleRule::InlineConstantScalars); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index b71a14cc044e..e06817e37c4c 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -291,7 +291,7 @@ ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { } TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") .set_body_typed(ScheduleRule::CrossThreadReduction); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 79cff3bad738..f020c8efd08a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -405,7 +405,7 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, } TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") .set_body_typed(SearchStrategy::EvolutionarySearch); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation") .set_body_typed(EvolutionarySearchSampleInitPopulation); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel") .set_body_typed(EvolutionarySearchEvolveWithCostModel); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 60552b2e14f8..51cc40839195 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -157,7 +157,7 @@ SearchStrategy SearchStrategy::ReplayFunc() { } TVM_REGISTER_NODE_TYPE(ReplayFuncNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") .set_body_typed(SearchStrategy::ReplayFunc); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index f7960688cd80..c9a7459fdf61 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -191,7 +191,7 @@ SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 1bc71502ad36..8fc6538b59f5 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -86,23 +86,23 @@ TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") .set_body_typed([](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { return MeasureCandidate(sch, args_info.value_or({})); }); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") .set_body_typed(SearchStrategy::PySearchStrategy); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") .set_body_method(&SearchStrategyNode::PreTuning); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") .set_body_method(&SearchStrategyNode::PostTuning); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") .set_body_method(&SearchStrategyNode::NotifyRunnerResults); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") .set_body_method(&SearchStrategyNode::Clone); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index da2178c736a1..91d5ba53d551 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -116,7 +116,7 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, } TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") .set_body_typed(SpaceGenerator::PostOrderApply); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 89a02876f3d9..f7f2a3ba19de 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -97,7 +97,7 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, } TVM_REGISTER_NODE_TYPE(ScheduleFnNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") .set_body_typed(SpaceGenerator::ScheduleFn); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 8712f5ad4892..7306fffcb1af 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -190,13 +190,13 @@ SpaceGenerator SpaceGenerator::PySpaceGenerator( TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") .set_body_method(&SpaceGeneratorNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") .set_body_typed(SpaceGenerator::PySpaceGenerator); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") .set_body_method(&SpaceGeneratorNode::Clone); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 819a4ee5f795..12bf75349430 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -82,7 +82,7 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g } TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index c750067ace9f..23d23e624394 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -145,7 +145,7 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i } TVM_REGISTER_NODE_TYPE(GradientBasedNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") .set_body_typed(TaskScheduler::GradientBased); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index d7c6f37e121d..9792fa7e7c25 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -63,7 +63,7 @@ TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { } TVM_REGISTER_NODE_TYPE(RoundRobinNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") .set_body_typed(TaskScheduler::RoundRobin); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 4411bd0792e0..85a406365377 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -364,18 +364,19 @@ void PyTaskSchedulerNode::Tune(Array tasks, Array task_we TVM_REGISTER_NODE_TYPE(TaskRecordNode); TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") .set_body_typed(TaskScheduler::PyTaskScheduler); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune").set_body_method(&TaskSchedulerNode::Tune); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") + .set_body_method(&TaskSchedulerNode::Tune); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") .set_body_method(&TaskSchedulerNode::NextTaskId); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") .set_body_method(&TaskSchedulerNode::TerminateTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") .set_body_method(&TaskSchedulerNode::TouchTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") .set_body_method(&TaskSchedulerNode::PrintTuningStatistics); } // namespace meta_schedule diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 5ba3f3123cbb..9d22554d912f 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -254,7 +254,7 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm } } -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") .set_body_typed(ScheduleUsingAnchorTrace); } // namespace meta_schedule diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 275f8d124cd1..31120ce45d4a 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -63,7 +63,7 @@ void TuneContextNode::Initialize() { } TVM_REGISTER_NODE_TYPE(TuneContextNode); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContext") .set_body_typed([](Optional mod, Optional target, Optional space_generator, Optional search_strategy, Optional task_name, @@ -72,10 +72,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, rand_state, logger); }); -TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") +TVM_FFI_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") .set_body_method(&TuneContextNode::Initialize); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClone").set_body_method(&TuneContextNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextClone").set_body_method(&TuneContextNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index 5fb3d503f55a..7441db783296 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -21,9 +21,9 @@ * Printer implementation for containers * \file node/container_printint.cc */ +#include #include #include -#include namespace tvm { diff --git a/src/node/object_path.cc b/src/node/object_path.cc index bfb05aa4bc60..a99835ea17ad 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -17,10 +17,10 @@ * under the License. */ +#include #include #include #include -#include #include #include @@ -40,13 +40,13 @@ Optional ObjectPathNode::GetParent() const { return Downcast>(parent_); } -TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_method(&ObjectPathNode::GetParent); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_method(&ObjectPathNode::GetParent); // --- Length --- int32_t ObjectPathNode::Length() const { return length_; } -TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); // --- GetPrefix --- @@ -63,7 +63,7 @@ ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { return GetRef(node); } -TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix").set_body_method(&ObjectPathNode::GetPrefix); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetPrefix").set_body_method(&ObjectPathNode::GetPrefix); // --- IsPrefixOf --- @@ -75,7 +75,7 @@ bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { return this->PathsEqual(other->GetPrefix(this_len)); } -TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf").set_body_method(&ObjectPathNode::IsPrefixOf); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf").set_body_method(&ObjectPathNode::IsPrefixOf); // --- Attr --- @@ -95,7 +95,7 @@ ObjectPath ObjectPathNode::Attr(Optional attr_key) const { } } -TVM_REGISTER_GLOBAL("node.ObjectPathAttr") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathAttr") .set_body_typed([](const ObjectPath& object_path, Optional attr_key) { return object_path->Attr(attr_key); }); @@ -106,7 +106,7 @@ ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex").set_body_method(&ObjectPathNode::ArrayIndex); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathArrayIndex").set_body_method(&ObjectPathNode::ArrayIndex); // --- MissingArrayElement --- @@ -114,7 +114,7 @@ ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") .set_body_method(&ObjectPathNode::MissingArrayElement); // --- MapValue --- @@ -123,7 +123,7 @@ ObjectPath ObjectPathNode::MapValue(Any key) const { return ObjectPath(make_object(this, std::move(key))); } -TVM_REGISTER_GLOBAL("node.ObjectPathMapValue").set_body_method(&ObjectPathNode::MapValue); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMapValue").set_body_method(&ObjectPathNode::MapValue); // --- MissingMapEntry --- @@ -131,7 +131,7 @@ ObjectPath ObjectPathNode::MissingMapEntry() const { return ObjectPath(make_object(this)); } -TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") .set_body_method(&ObjectPathNode::MissingMapEntry); // --- PathsEqual ---- @@ -158,7 +158,7 @@ bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { return lhs == nullptr && rhs == nullptr; } -TVM_REGISTER_GLOBAL("node.ObjectPathEqual").set_body_method(&ObjectPathNode::PathsEqual); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathEqual").set_body_method(&ObjectPathNode::PathsEqual); // --- Repr --- @@ -191,7 +191,7 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const { return ObjectPath(make_object(name)); } -TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); // ============== Individual path classes ============== diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 6ba6a4ef1f18..2290403d3730 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -21,10 +21,10 @@ * Reflection utilities. * \file node/reflection.cc */ +#include #include #include #include -#include namespace tvm { @@ -292,11 +292,11 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1)); } -TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); +TVM_FFI_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); -TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNames); +TVM_FFI_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNames); -TVM_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); +TVM_FFI_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); namespace { // Attribute visitor class for finding the attribute key by its address diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 3e80751e6604..aa999655c03d 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -21,9 +21,9 @@ * Printer utilities * \file node/repr_printer.cc */ +#include #include #include -#include namespace tvm { @@ -133,12 +133,12 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { +TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { std::ostringstream os; os << obj; return os.str(); }); -TVM_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); +TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index dd249b77dbae..ee7880f4485a 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -16,10 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include -#include #include @@ -135,9 +135,9 @@ Array PrinterConfigNode::GetBuiltinKeywords() { } TVM_REGISTER_NODE_TYPE(PrinterConfigNode); -TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { +TVM_FFI_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { return PrinterConfig(config_dict); }); -TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); +TVM_FFI_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 7f74012c0de5..631d70e2356c 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -23,18 +23,17 @@ */ #include #include +#include #include #include #include #include #include -#include #include #include #include -#include "../runtime/object_internal.h" #include "../support/base64.h" namespace tvm { @@ -701,7 +700,7 @@ Any LoadJSON(std::string json_str) { return nodes.at(jgraph.root); } -TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); +TVM_FFI_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); -TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); +TVM_FFI_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 17a96b5e92a2..6b19fb5355bb 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,13 +19,13 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include #include #include #include -#include #include #include @@ -36,12 +36,12 @@ namespace tvm { TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); -TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairLhsPath") .set_body_typed([](const ObjectPathPair& object_path_pair) { return object_path_pair->lhs_path; }); -TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairRhsPath") .set_body_typed([](const ObjectPathPair& object_path_pair) { return object_path_pair->rhs_path; }); @@ -595,7 +595,7 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); } -TVM_REGISTER_GLOBAL("node.StructuralEqual") +TVM_FFI_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { // If we are asserting on failure, then the `defer_fails` option // should be enabled, to provide better error messages. For @@ -608,7 +608,7 @@ TVM_REGISTER_GLOBAL("node.StructuralEqual") .Equal(lhs, rhs, map_free_vars); }); -TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") +TVM_FFI_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") .set_body_typed([](const Any& lhs, const Any& rhs, bool map_free_vars) { Optional first_mismatch; bool equal = diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 14c7b8a39a91..efaa7037b013 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -20,13 +20,13 @@ * \file src/node/structural_hash.cc */ #include +#include #include #include #include #include #include #include -#include #include #include @@ -291,7 +291,7 @@ void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars impl->DispatchSHash(key, map_free_vars); } -TVM_REGISTER_GLOBAL("node.StructuralHash") +TVM_FFI_REGISTER_GLOBAL("node.StructuralHash") .set_body_typed([](const Any& object, bool map_free_vars) -> int64_t { uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); return static_cast(hashed_value); diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index b4a5a039bbaa..98122d1e1ec8 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -197,15 +197,15 @@ bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { return FindImpureCall(expr, own_name).defined(); } -TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); -TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); -TVM_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 37bbf3a9775e..ba163b51d6c9 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -92,7 +92,7 @@ Array ComputableAtCompileTime(const Function& func) { return CompileTimeCollector::Collect(func); } -TVM_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") .set_body_typed(ComputableAtCompileTime); } // namespace relax diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 9c150fed8bfd..48ec7880b172 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -392,7 +392,7 @@ tvm::Array> DetectRecursion(const IRModule& m) { return ret; } -TVM_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index f0658dabb398..ab32abab5bea 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -614,7 +614,7 @@ Map> SuggestLayoutTransforms( return analyzer.GetSuggestedTransforms(); } -TVM_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) .set_body_typed([](PrimFunc fn, Array write_buffer_transformations) { return SuggestLayoutTransforms(fn, write_buffer_transformations); }); diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 5ad8a173c0b4..e09f061001f9 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -72,7 +72,7 @@ class StaticTypeDeriver : public StructInfoFunctor { Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } -TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { +TVM_FFI_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { return GetStaticType(info); }); @@ -285,7 +285,7 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); } -TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") .set_body_typed([](const StructInfo& info, Map shape_var_map, Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); @@ -595,7 +595,7 @@ BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& de } } -TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { return static_cast(StructInfoBaseCheck(base, derived)); }); @@ -604,7 +604,7 @@ bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } -TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") +TVM_FFI_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") .set_body_typed([](const StructInfo& base, const StructInfo& derived) { return IsBaseOf(base, derived); }); @@ -955,7 +955,7 @@ StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call } } -TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { return DeriveCallRetStructInfo(finfo, call, ctx); }); @@ -1158,7 +1158,7 @@ StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::An } } -TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { return StructInfoLCA(lhs, rhs); }); @@ -1241,9 +1241,9 @@ Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { return detector.GetTIRVars(); } -TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo); -TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") .set_body_typed(DefinableTIRVarsInStructInfo); class NonNegativeExpressionCollector : relax::StructInfoVisitor { @@ -1288,7 +1288,7 @@ Array CollectNonNegativeExpressions(const StructInfo& sinfo) { return NonNegativeExpressionCollector::Collect(sinfo); } -TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") .set_body_typed(CollectNonNegativeExpressions); class SymbolicVarCollector : public relax::ExprVisitor, @@ -1436,9 +1436,9 @@ Array DefinedSymbolicVars(const Expr& expr) { } Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } -TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars); -TVM_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 177fdeff588c..0845ec092fe2 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -537,7 +537,7 @@ bool HasReshapePattern(const PrimFunc& func) { return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } -TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 1c16f25c997f..f62254b6959d 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -118,7 +118,7 @@ Map> DataflowBlockUseDef(const DataflowBlock& dfb) { return usage.downstream_usage; } -TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index 7625841eb69c..a367d33ca4ff 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -58,7 +58,7 @@ Map AnalyzeVar2Value(const IRModule& m) { return std::move(var2val_analysis.var2value_); } -TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { return AnalyzeVar2Value(f); }); @@ -85,7 +85,7 @@ Map> NameToBinding(const Function& fn) { std::make_move_iterator(analysis.name2bindings_.end())); } -TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 316e87f351fa..243033e9454b 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -646,7 +646,7 @@ bool WellFormed(Variant obj, bool check_struct_info) { return WellFormedChecker::Check(obj, check_struct_info); } -TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed); +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed); } // namespace relax } // namespace tvm diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 3c87079f99d5..ff64504f6111 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -322,7 +322,7 @@ Array OpenCLMLCompiler(Array functions, Map CublasCompiler(Array functions, Map cuDNNCompiler(Array functions, Map headers) { return CodegenResult(code, headers); }); @@ -385,7 +385,7 @@ Array CUTLASSCompiler(Array functions, Map DNNLCompiler(Array functions, Map HipblasCompiler(Array functions, Map #include #include #include #include -#include #include #include @@ -264,7 +264,7 @@ Array NNAPICompiler(Array functions, Map TensorRTCompiler(Array functions, Map GetTensorRTVersion() { #endif // TVM_GRAPH_EXECUTOR_TENSORRT } -TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled); -TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); +TVM_FFI_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled") + .set_body_typed(IsTensorRTRuntimeEnabled); +TVM_FFI_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index 8e214809dd51..6574ccc37a15 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -75,7 +75,7 @@ bool EndsWithPattern(const std::string& str, const std::string& pattern) { return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; } -TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); +TVM_FFI_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); } // namespace backend } // namespace relax diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 9feeca6662dc..840b44c12838 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -67,10 +67,11 @@ Optional GetPattern(const String& pattern_name) { return std::nullopt; } -TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); -TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); -TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix); -TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); +TVM_FFI_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); +TVM_FFI_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); +TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix") + .set_body_typed(GetPatternsWithPrefix); +TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); } // namespace backend } // namespace relax diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index af6e83cf5f9a..686d24de62b2 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -139,7 +139,7 @@ class TaskExtractor : public ExprVisitor { std::optional normalize_mod_func_; }; -TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") +TVM_FFI_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") .set_body_typed([](IRModule mod, Target target, String mod_eq_name) { return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); }); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 44acd6ea81c2..f61579e25e96 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -425,7 +425,7 @@ IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVM::Run(exec_builder, mod); } -TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); +TVM_FFI_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); /*! * \brief Link the modules together, possibly create a constant module. @@ -490,7 +490,7 @@ Module VMLink(ExecBuilder builder, Target target, Optional lib, Array(); ffi::Any rt; @@ -337,20 +337,21 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") *ret = builder->ConvertConstant(rt).data(); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") .set_body_typed([](ExecBuilder builder, String func, int64_t num_inputs, Optional> param_names) { builder->EmitFunction(func, num_inputs, param_names); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEndFunction").set_body_method(&ExecBuilderNode::EndFunction); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEndFunction") + .set_body_method(&ExecBuilderNode::EndFunction); -TVM_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") .set_body_typed([](ExecBuilder builder, String name, int32_t kind) { builder->DeclareFunction(name, static_cast(kind)); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { std::vector args_; for (size_t i = 0; i < args.size(); ++i) { @@ -360,35 +361,38 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") builder->EmitCall(name, args_, dst_.value()); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") .set_body_typed([](ExecBuilder builder, int64_t data) { builder->EmitRet(Instruction::Arg::FromData(data)); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto").set_body_method(&ExecBuilderNode::EmitGoto); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto").set_body_method(&ExecBuilderNode::EmitGoto); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") .set_body_typed([](ExecBuilder builder, int64_t data, vm::Index false_offset) { builder->EmitIf(Instruction::Arg::FromData(data), false_offset); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Register(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderR") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Register(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Immediate(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderImm") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Immediate(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::ConstIdx(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderC") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::ConstIdx(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); return runtime::Module(p_exec); }); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index f746e4a5afd2..7757195bcb1d 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -231,7 +231,7 @@ Pass LowerRuntimeBuiltin() { return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); } // namespace transform } // namespace relax diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 6400a3332133..0b60553034fe 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -813,7 +813,7 @@ Pass VMShapeLower(bool emit_err_ctx) { return CreateModulePass(pass_func, 0, "VMShapeLower", {}); } -TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { return VMShapeLower(emit_err_ctx); }); diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index c0d6204c8739..e1cc32fc82e3 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -57,7 +57,7 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { } TVM_REGISTER_NODE_TYPE(DeviceMeshNode); -TVM_REGISTER_GLOBAL("relax.distributed.DeviceMesh") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.DeviceMesh") .set_body_typed([](ffi::Shape shape, Array device_ids, Optional device_range) { if (device_range.defined()) return DeviceMesh(shape, device_range.value()); diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 3569b1538551..0ff9d4d6fa09 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -43,11 +43,11 @@ PlacementSpec PlacementSpec::Replica() { TVM_REGISTER_NODE_TYPE(PlacementSpecNode); -TVM_REGISTER_GLOBAL("relax.distributed.Sharding").set_body_typed([](int axis) { +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Sharding").set_body_typed([](int axis) { return PlacementSpec::Sharding(axis); }); -TVM_REGISTER_GLOBAL("relax.distributed.Replica").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Replica").set_body_typed([]() { return PlacementSpec::Replica(); }); @@ -106,8 +106,8 @@ Placement Placement::FromText(String text_repr) { } TVM_REGISTER_NODE_TYPE(PlacementNode); -TVM_REGISTER_GLOBAL("relax.distributed.PlacementFromText").set_body_typed(Placement::FromText); -TVM_REGISTER_GLOBAL("relax.distributed.Placement") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.PlacementFromText").set_body_typed(Placement::FromText); +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Placement") .set_body_typed([](Array dim_specs) { return Placement(dim_specs); }); // DTensor @@ -130,7 +130,7 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d TVM_REGISTER_NODE_TYPE(DTensorStructInfoNode); -TVM_REGISTER_GLOBAL("relax.distributed.DTensorStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.DTensorStructInfo") .set_body_typed([](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { return DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 5ed947858775..1df1d2110ba9 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -115,7 +115,7 @@ Pass LegalizeRedistribute() { }; return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute") .set_body_typed(LegalizeRedistribute); } // namespace transform diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 59de65450e9e..e4f811b83d42 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -262,7 +262,7 @@ Pass LowerDistIR() { auto pass_func = [=](IRModule m, PassContext pc) { return DistIRSharder::LowerDistIR(m); }; return CreateModulePass(pass_func, 1, "LowerDistIR", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR); +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR); } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 514a98ef44f3..c8abe2b1d1b5 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -432,7 +432,7 @@ Pass LowerGlobalViewToLocalView() { auto pass_func = [=](IRModule m, PassContext pc) { return LowerTIRToLocalView(m).Lower(); }; return CreateModulePass(pass_func, 1, "LowerGlobalViewToLocalView", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LowerGlobalViewToLocalView") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerGlobalViewToLocalView") .set_body_typed(LowerGlobalViewToLocalView); } // namespace transform diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 15b372fb8348..f5f276c2b873 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -615,7 +615,7 @@ Pass PropagateSharding() { }; return CreateModulePass(pass_func, 1, "PropagateSharding", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.PropagateSharding") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.PropagateSharding") .set_body_typed(PropagateSharding); } // namespace transform diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 8a70f1832440..f35b443b5b39 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -51,7 +51,7 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlockRewrite") .set_body_typed([](DataflowBlock dfb, Function root_fn) { return DataflowBlockRewrite(dfb, root_fn); }); @@ -110,7 +110,7 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { } } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { rwt->ReplaceAllUses(old_var, new_var); }); @@ -178,10 +178,10 @@ void DataflowBlockRewriteNode::Add(Binding binding) { } } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add") .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { if (name.has_value()) { rwt->Add(name.value(), expr, is_dfvar); @@ -292,7 +292,7 @@ void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { to_users_.erase(unused); // update use-def chain. } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { rwt->RemoveUnused(unused, allow_undef); }); @@ -314,7 +314,7 @@ void DataflowBlockRewriteNode::RemoveAllUnused() { for (const auto& unused : remover.unused_vars) to_users_.erase(unused); } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); Expr RemoveAllUnused(Expr expr) { @@ -333,7 +333,7 @@ Expr RemoveAllUnused(Expr expr) { return remover.VisitExpr(std::move(expr)); } -TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { BlockBuilder builder = BlockBuilder::Create(irmod); @@ -348,7 +348,7 @@ IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { return builder->GetContextIRModule(); } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { return rwt->MutateIRModule(irmod); }); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index db173fe9d069..63288201e741 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -21,6 +21,7 @@ * \file src/relax/block_builder.cc */ #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include #include @@ -1054,65 +1054,67 @@ BlockBuilder BlockBuilder::Create(Optional mod, //--------------------------------------- TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); -TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { return BlockBuilder::Create(mod); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") .set_body_method(&BlockBuilderNode::BeginDataflowBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") .set_body_method(&BlockBuilderNode::BeginBindingBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_method(&BlockBuilderNode::EndBlock); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_method(&BlockBuilderNode::EndBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize").set_body_method(&BlockBuilderNode::Normalize); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmit") .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) { return builder->Emit(expr, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) { return builder->EmitMatchCast(value, struct_info, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) { return builder->EmitOutput(output, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") .set_body_typed([](BlockBuilder builder, Binding binding) { return builder->EmitNormalized(binding); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") .set_body_typed([](BlockBuilder builder, String name_hint) { return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false, /*add_underscore*/ false); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") .set_body_method(&BlockBuilderNode::AddFunction); -TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") .set_body_method(&BlockBuilderNode::UpdateFunction); -TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") .set_body_method(&BlockBuilderNode::GetContextIRModule); -TVM_REGISTER_GLOBAL("relax.BlockBuilderFinalize").set_body_method(&BlockBuilderNode::Finalize); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderFinalize").set_body_method(&BlockBuilderNode::Finalize); -TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); -TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") .set_body_method(&BlockBuilderNode::LookupBinding); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope").set_body_method(&BlockBuilderNode::BeginScope); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope").set_body_method(&BlockBuilderNode::EndScope); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndScope").set_body_method(&BlockBuilderNode::EndScope); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index b82269f22813..172f4d7bcb27 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -362,7 +362,7 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_dfb") .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { return MatchGraph(ctx, dfb); }); @@ -447,7 +447,7 @@ Function RewriteBindings( return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 1c5aa4a0cc37..c398305d938c 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -193,17 +193,16 @@ void RewriteSpec::Append(RewriteSpec other) { TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") .set_body_typed([](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { return PatternMatchingRewriter::FromPattern(pattern, func); }); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule").set_body_typed([](IRModule mod) { - return PatternMatchingRewriter::FromModule(mod); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule") + .set_body_typed([](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); }); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") .set_body_typed([](PatternMatchingRewriter rewriter, Variant obj) -> Variant { if (auto expr = obj.as()) { @@ -259,7 +258,7 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, return std::nullopt; } -TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternRewriter") .set_body_typed([](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { return ExprPatternRewriter(pattern, func); @@ -308,7 +307,7 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons return lhs_match; } -TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrRewriter") .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { return OrRewriter(lhs, rhs); }); @@ -603,7 +602,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( return rewrites; } -TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleRewriter") .set_body_typed([](Array patterns, ffi::TypedFunction(Expr, Map)> func) { return TupleRewriter(patterns, func); @@ -796,13 +795,13 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return matcher.GetMemo(); } -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); } -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); /*! * \brief Apply pattern matching to each expression, replacing @@ -1074,7 +1073,7 @@ Function RewriteCall(const DFPattern& pat, return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 7f045ebe486e..db242b773be6 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -44,7 +44,7 @@ ExternFuncPattern::ExternFuncPattern(String global_symbol) { n->global_symbol_ = std::move(global_symbol); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { return ExternFuncPattern(global_symbol); }); RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { @@ -57,7 +57,7 @@ VarPattern::VarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { return VarPattern(name_hint); }); RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { @@ -65,7 +65,7 @@ RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { }); TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); -TVM_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { return DataflowVarPattern(name_hint); }); DataflowVarPattern::DataflowVarPattern(String name_hint) { @@ -83,7 +83,7 @@ GlobalVarPattern::GlobalVarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { return GlobalVarPattern(name_hint); }); RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { @@ -96,11 +96,13 @@ ExprPattern::ExprPattern(Expr expr) { n->expr = std::move(expr); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { return ExprPattern(e); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { + return ExprPattern(e); +}); RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); TVM_REGISTER_NODE_TYPE(ConstantPatternNode); -TVM_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { auto c = ConstantPattern(make_object()); return c; }); @@ -115,7 +117,7 @@ CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_ n->varg_default_wildcard = varg_default_wildcard; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.CallPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.CallPattern") .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { return CallPattern(op, args, varg_default_wildcard); }); @@ -138,7 +140,7 @@ PrimArrPattern::PrimArrPattern(Array arr) { n->fields = std::move(arr); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { return PrimArrPattern(std::move(arr)); }); RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { @@ -152,7 +154,7 @@ FunctionPattern::FunctionPattern(Array params, DFPattern body) { n->body = std::move(body); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.FunctionPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.FunctionPattern") .set_body_typed([](Array params, DFPattern body) { return FunctionPattern(params, body); }); @@ -166,7 +168,7 @@ TuplePattern::TuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { return TuplePattern(fields); }); RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { @@ -179,7 +181,7 @@ UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { p->stream << "UnorderedTuplePattern(" << node->fields << ")"; @@ -192,9 +194,8 @@ TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { n->index = index; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern").set_body_typed([](DFPattern tuple, int index) { - return TupleGetItemPattern(tuple, index); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern") + .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; }); @@ -206,7 +207,7 @@ AndPattern::AndPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { return AndPattern(left, right); }); RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { @@ -220,7 +221,7 @@ OrPattern::OrPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { return OrPattern(left, right); }); RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { @@ -233,7 +234,7 @@ NotPattern::NotPattern(DFPattern reject) { n->reject = std::move(reject); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { return NotPattern(reject); }); RELAX_PATTERN_PRINTER_DEF(NotPatternNode, @@ -241,7 +242,9 @@ RELAX_PATTERN_PRINTER_DEF(NotPatternNode, TVM_REGISTER_NODE_TYPE(WildcardPatternNode); WildcardPattern::WildcardPattern() { data_ = make_object(); } -TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { return WildcardPattern(); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { + return WildcardPattern(); +}); RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); TVM_REGISTER_NODE_TYPE(TypePatternNode); @@ -251,7 +254,7 @@ TypePattern::TypePattern(DFPattern pattern, Type type) { n->type = std::move(type); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { @@ -265,7 +268,7 @@ StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) n->struct_info = std::move(struct_info); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.StructInfoPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.StructInfoPattern") .set_body_typed([](DFPattern pattern, StructInfo struct_info) { return StructInfoPattern(pattern, struct_info); }); @@ -281,7 +284,7 @@ ShapePattern::ShapePattern(DFPattern pattern, Array shape) { n->shape = std::move(shape); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ShapePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ShapePattern") .set_body_typed([](DFPattern pattern, Array shape) { return ShapePattern(pattern, shape); }); @@ -299,7 +302,7 @@ SameShapeConstraint::SameShapeConstraint(Array args) { ctx.value().add_constraint(*this); } } -TVM_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array args) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array args) { return SameShapeConstraint(args); }); RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { @@ -320,7 +323,7 @@ DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { n->dtype = std::move(dtype); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.DataTypePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataTypePattern") .set_body_typed([](DFPattern pattern, DataType dtype) { return DataTypePattern(pattern, dtype); }); @@ -335,9 +338,8 @@ AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { n->attrs = std::move(attrs); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.AttrPattern").set_body_typed([](DFPattern pattern, DictAttrs attrs) { - return AttrPattern(pattern, attrs); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.AttrPattern") + .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; }); @@ -511,7 +513,7 @@ PatternSeq PatternSeq::dup() const { return ret; } -TVM_REGISTER_GLOBAL("relax.dpl.PatternSeq") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternSeq") .set_body_typed([](Array patterns, bool only_used_by) { return PatternSeq(std::move(patterns), only_used_by); }); @@ -525,12 +527,12 @@ RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "]"; }); -TVM_REGISTER_GLOBAL("relax.dpl.used_by") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.used_by") .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.UsedBy(rhs, index); }); -TVM_REGISTER_GLOBAL("relax.dpl.only_used_by") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.only_used_by") .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.OnlyUsedBy(rhs, index); }); @@ -643,25 +645,27 @@ DFPattern DFPattern::dup() const { return pattern; } -TVM_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { return pattern.dup(); }); -TVM_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { return seq.dup(); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { + return seq.dup(); +}); -TVM_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { return PatternContext(incre); }); -TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { return PatternContext::Current(); }); -TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { ctx.EnterWithScope(); }); -TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { ctx.ExitWithScope(); }); diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 86c38d08a1c1..e75dc3c2d7ca 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -72,7 +72,7 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string return te::PlaceholderOp(n).output(0); } -TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); +TVM_FFI_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 4981e9b9f850..238cece41f61 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -97,7 +97,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_GLOBAL("relax.Call") +TVM_FFI_REGISTER_GLOBAL("relax.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { return Call(op, args, attrs, sinfo_args, span); }); @@ -132,7 +132,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_GLOBAL("relax.If") +TVM_FFI_REGISTER_GLOBAL("relax.If") .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { return If(cond, true_branch, false_branch, span); }); @@ -162,7 +162,7 @@ Tuple::Tuple(tvm::Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); @@ -226,7 +226,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { return TupleGetItem(tuple, index, span); }); @@ -249,7 +249,7 @@ ShapeExpr::ShapeExpr(Array values, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { return ShapeExpr(values, span); }); @@ -285,12 +285,12 @@ VarNode* Var::CopyOnWrite() { return static_cast(data_.get()); } -TVM_REGISTER_GLOBAL("relax.Var") +TVM_FFI_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); }); -TVM_REGISTER_GLOBAL("relax.VarFromId") +TVM_FFI_REGISTER_GLOBAL("relax.VarFromId") .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { return Var(vid, struct_info_annotation, span); }); @@ -309,12 +309,12 @@ DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Sp data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowVar") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowVar") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return DataflowVar(name_hint, struct_info_annotation, span); }); -TVM_REGISTER_GLOBAL("relax.DataflowVarFromId") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowVarFromId") .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { return DataflowVar(vid, struct_info_annotation, span); }); @@ -344,7 +344,7 @@ Constant::Constant(runtime::NDArray data, Optional struct_info_annot TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relax.Constant") +TVM_FFI_REGISTER_GLOBAL("relax.Constant") .set_body_typed([](runtime::NDArray data, Optional struct_info_annotation = std::nullopt, Span span = Span()) { @@ -366,7 +366,7 @@ PrimValue PrimValue::Int64(int64_t value, Span span) { TVM_REGISTER_NODE_TYPE(PrimValueNode); -TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { return PrimValue(value, span); }); @@ -383,7 +383,7 @@ StringImm::StringImm(String value, Span span) { TVM_REGISTER_NODE_TYPE(StringImmNode); -TVM_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { return StringImm(value, span); }); @@ -400,7 +400,7 @@ DataTypeImm::DataTypeImm(DataType value, Span span) { TVM_REGISTER_NODE_TYPE(DataTypeImmNode); -TVM_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { return DataTypeImm(value, span); }); @@ -416,7 +416,7 @@ MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.MatchCast") +TVM_FFI_REGISTER_GLOBAL("relax.MatchCast") .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { return MatchCast(var, value, struct_info, span); }); @@ -458,7 +458,7 @@ VarBinding::VarBinding(Var var, Expr value, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { return VarBinding(var, value, span); }); @@ -513,9 +513,10 @@ BindingBlockNode* BindingBlock::CopyOnWrite() { return static_cast(data_.get()); } -TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { - return BindingBlock(bindings, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.BindingBlock") + .set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); + }); TVM_REGISTER_NODE_TYPE(DataflowBlockNode); @@ -526,9 +527,10 @@ DataflowBlock::DataflowBlock(Array bindings, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { - return DataflowBlock(bindings, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlock") + .set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); + }); TVM_REGISTER_NODE_TYPE(SeqExprNode); @@ -548,7 +550,7 @@ SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.SeqExpr") +TVM_FFI_REGISTER_GLOBAL("relax.SeqExpr") .set_body_typed([](Array blocks, Expr body, Span span) { return SeqExpr(blocks, body, span); }); @@ -624,7 +626,7 @@ Function::Function(Array params, Expr body, Optional ret_struct data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.Function") +TVM_FFI_REGISTER_GLOBAL("relax.Function") .set_body_typed([](Array params, Expr body, Optional ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { return Function(params, body, ret_struct_info, is_pure, attrs, span); @@ -662,7 +664,7 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo return Function(std::move(n)); } -TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") +TVM_FFI_REGISTER_GLOBAL("relax.FunctionCreateEmpty") .set_body_typed([](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); @@ -670,7 +672,7 @@ TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") // Special opaque derivation function for ExternFunc // Take look at sinfo_args to figure out the return StructInfo. -TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") +TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; if (call->sinfo_args.empty()) { @@ -708,7 +710,7 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ExternFunc") +TVM_FFI_REGISTER_GLOBAL("relax.ExternFunc") .set_body_typed([](String global_symbol, Optional struct_info, Span span) { if (struct_info.defined()) { return ExternFunc(global_symbol, struct_info.value(), span); @@ -732,11 +734,11 @@ Expr GetShapeOf(const Expr& expr) { return call_shape_of; } -TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { +TVM_FFI_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { return GetShapeOf(expr); }); -TVM_REGISTER_GLOBAL("relax.FuncWithAttr") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); @@ -744,7 +746,7 @@ TVM_REGISTER_GLOBAL("relax.FuncWithAttr") return std::nullopt; }); -TVM_REGISTER_GLOBAL("relax.FuncWithAttrs") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttrs") .set_body_typed([](BaseFunc func, Map attr_map) -> Optional { if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); @@ -752,7 +754,7 @@ TVM_REGISTER_GLOBAL("relax.FuncWithAttrs") return std::nullopt; }); -TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithoutAttr") .set_body_typed([](BaseFunc func, String key) -> Optional { if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 0b497a97de92..5e04453a1227 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -326,7 +326,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.post_order_visit") .set_body_typed([](Expr expr, ffi::Function f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 4a36bf214884..dc355cef905f 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -540,30 +540,30 @@ class PyExprMutator : public ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); }; -TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); +TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { visitor->VisitBinding(binding); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { visitor->VisitBindingBlock(block); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->ExprVisitor::VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { if (const auto* ptr = binding.as()) { visitor->ExprVisitor::VisitBinding_(ptr); @@ -574,7 +574,7 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { if (const auto* ptr = block.as()) { visitor->ExprVisitor::VisitBindingBlock_(ptr); @@ -585,7 +585,7 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") .set_body_typed([](PyExprVisitor visitor, const Var& var) { if (const auto* node = var.as()) { visitor->ExprVisitor::VisitVarDef_(node); @@ -596,39 +596,39 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") .set_body_typed([](PyExprVisitor visitor, const Span& span) { visitor->ExprVisitor::VisitSpan(span); }); -TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); +TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") .set_body_typed([](PyExprMutator mutator, const Binding& binding) { mutator->VisitBinding(binding); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { return mutator->VisitBindingBlock(block); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") .set_body_typed([](PyExprMutator mutator, const Var& var) { return mutator->VisitVarDef(var); }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->ExprMutator::VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") .set_body_typed([](PyExprMutator mutator, const Binding& binding) { if (const auto* ptr = binding.as()) { return mutator->ExprMutator::VisitBinding_(ptr); @@ -639,7 +639,7 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") } }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { if (const auto* node = block.as()) { return mutator->ExprMutator::VisitBindingBlock_(node); @@ -650,7 +650,7 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") } }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") .set_body_typed([](PyExprMutator mutator, const Var& var) { if (const auto* node = var.as()) { return mutator->ExprMutator::VisitVarDef_(node); @@ -661,32 +661,32 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") } }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExprPostOrder(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitWithNewScope(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") .set_body_typed([](PyExprMutator mutator, const Var& var) { return mutator->LookupBinding(var); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { return mutator->WithStructInfo(var, sinfo); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") .set_body_typed([](PyExprMutator mutator, Id id, Var var) { return mutator->var_remap_[id] = var; }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); } // namespace relax diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 38e54d4794a5..feb1f910a42c 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -21,10 +21,10 @@ * \file src/relax/ir/struct_info.cc * \brief Relax struct info. */ +#include #include #include #include -#include namespace tvm { namespace relax { @@ -37,7 +37,7 @@ ObjectStructInfo::ObjectStructInfo(Span span) { TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); -TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { return ObjectStructInfo(span); }); @@ -60,13 +60,11 @@ PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); -TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) { - return PrimStructInfo(dtype, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype") + .set_body_typed([](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }); -TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) { - return PrimStructInfo(value, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromValue") + .set_body_typed([](PrimExpr value, Span span) { return PrimStructInfo(value, span); }); // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { @@ -94,7 +92,7 @@ ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); -TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.ShapeStructInfo") .set_body_typed([](Optional> values, int ndim, Span span) { if (values.defined()) { CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; @@ -135,7 +133,7 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional v TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); -TVM_REGISTER_GLOBAL("relax.TensorStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.TensorStructInfo") .set_body_typed([](Optional shape, Optional dtype, int ndim, VDevice vdevice, Span span) { if (shape.defined()) { @@ -156,7 +154,7 @@ TupleStructInfo::TupleStructInfo(Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); -TVM_REGISTER_GLOBAL("relax.TupleStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.TupleStructInfo") .set_body_typed([](Array fields, Span span) { return TupleStructInfo(fields, span); }); @@ -191,12 +189,12 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); -TVM_REGISTER_GLOBAL("relax.FuncStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfo") .set_body_typed([](Array params, StructInfo ret, bool purity, Span span) { return FuncStructInfo(params, ret, purity, span); }); -TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") +TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") .set_body_typed([](Optional ret, Optional derive_func, bool purity, Span span) { if (derive_func.defined()) { @@ -220,11 +218,10 @@ void UpdateStructInfo(Expr expr, StructInfo struct_info) { expr->checked_type_ = GetStaticType(struct_info); } -TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { - UpdateStructInfo(expr, struct_info); -}); +TVM_FFI_REGISTER_GLOBAL("relax.UpdateStructInfo") + .set_body_typed([](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }); -TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { +TVM_FFI_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { return GetStructInfo(expr); }); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index d79d8b3fd50d..a44deba0fe94 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -22,13 +22,13 @@ * \brief Relax specific transformation passes. */ #include +#include #include #include #include #include #include #include -#include namespace tvm { namespace relax { @@ -163,7 +163,7 @@ Pass CreateFunctionPass(std::function TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { @@ -383,7 +383,7 @@ Pass CreateDataflowBlockPass( TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); -TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 82b95b556bc2..8b70bcf2c7a5 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -21,8 +21,8 @@ * \file src/relax/ir/type.cc * \brief Relax type system. */ +#include #include -#include namespace tvm { namespace relax { @@ -36,7 +36,7 @@ ShapeType::ShapeType(int ndim, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { return ShapeType(ndim, span); }); @@ -48,7 +48,9 @@ ObjectType::ObjectType(Span span) { TVM_REGISTER_NODE_TYPE(ObjectTypeNode); -TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); }); +TVM_FFI_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { + return ObjectType(span); +}); TensorType::TensorType(int ndim, DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -68,7 +70,7 @@ TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("relax.TensorType").set_body_typed([](int ndim, DataType dtype, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.TensorType").set_body_typed([](int ndim, DataType dtype, Span span) { return TensorType(ndim, dtype, span); }); @@ -80,7 +82,7 @@ PackedFuncType::PackedFuncType(Span span) { TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); -TVM_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { return PackedFuncType(span); }); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index c32cdc3aacb3..2f6314221ecc 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -36,7 +36,7 @@ Expr allreduce(Expr x, String op_type, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -63,7 +63,7 @@ Expr allgather(Expr x, int num_workers, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -94,7 +94,8 @@ Expr broadcast_from_worker0(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0").set_body_typed(broadcast_from_worker0); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0") + .set_body_typed(broadcast_from_worker0); StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -120,7 +121,7 @@ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0); StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index cdeb537c3d9f..84750c0c9c4c 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -48,7 +48,7 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.annotate_sharding").set_body_typed(annotate_sharding); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.annotate_sharding").set_body_typed(annotate_sharding); StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -73,7 +73,7 @@ Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.redistribute").set_body_typed(redistribute); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute").set_body_typed(redistribute); StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -139,7 +139,7 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, return call; } -TVM_REGISTER_GLOBAL("relax.op.dist.call_tir_local_view").set_body_typed(MakeCallTIRLocalView); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.call_tir_local_view").set_body_typed(MakeCallTIRLocalView); StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -208,7 +208,7 @@ Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { return Call(op, {std::move(input)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard") +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard") .set_body_typed(redistribute_replica_to_shard); TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 5b6550c72903..c45e24df2b13 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -50,7 +50,7 @@ Expr resize2d(Expr data, Expr size, Array roi, String layout, String m return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1 && call->args.size() != 2) { diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 0b8243a88977..a7465db868fe 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -40,7 +40,7 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel }); } -TVM_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { @@ -289,7 +289,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } } -TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView); +TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo") + .set_body_typed(InferStructInfoView); Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; @@ -360,7 +361,7 @@ Expr ensure_zero_offset(const Expr& x) { return Call(op, {x}); } -TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index d6e7fa707d73..a084747a5cf3 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -57,8 +57,8 @@ Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr s {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); -TVM_REGISTER_GLOBAL("relax.op.nn.attention_var_len").set_body_typed(attention_var_len); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention_var_len").set_body_typed(attention_var_len); StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index cca50689cb02..f335bb9e7c7b 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -51,7 +51,7 @@ Expr conv1d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -214,7 +214,7 @@ Expr conv2d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -413,7 +413,7 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d); StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -593,7 +593,7 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -730,7 +730,7 @@ Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 2296c364fc64..b79690d3a9bd 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -50,7 +50,7 @@ Expr leakyrelu(Expr data, double alpha) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.leakyrelu").set_body_typed(leakyrelu); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.leakyrelu").set_body_typed(leakyrelu); TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) @@ -71,7 +71,7 @@ Expr softplus(Expr data, double beta, double threshold) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) @@ -91,7 +91,7 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) { return Call(op, {data, alpha}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); TVM_REGISTER_OP("relax.nn.prelu") .set_num_inputs(2) @@ -112,7 +112,7 @@ Expr softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -170,7 +170,7 @@ Expr log_softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) @@ -191,7 +191,7 @@ Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -234,7 +234,7 @@ Expr pixel_shuffle(Expr data, int upscale_factor) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -383,7 +383,7 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -461,7 +461,7 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -529,7 +529,7 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); @@ -636,7 +636,7 @@ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -695,7 +695,7 @@ Expr dropout(Expr data, double rate) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -763,7 +763,7 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels) { return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") .set_body_typed(cross_entropy_with_logits); TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") @@ -797,7 +797,7 @@ Expr nll_loss(Expr predictions, Expr targets, Optional weights, String red } } -TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 2 || call->args.size() > 3) { diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 391edda9ef38..0161a4d4195d 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -62,7 +62,7 @@ Expr max_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d); StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -175,7 +175,7 @@ Expr max_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -314,7 +314,7 @@ Expr max_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d); StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -401,7 +401,7 @@ Expr avg_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d); TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) @@ -420,7 +420,7 @@ Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) @@ -439,7 +439,7 @@ Expr avg_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d); TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) @@ -470,7 +470,7 @@ Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool1d").set_body_typed(adaptive_avg_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool1d").set_body_typed(adaptive_avg_pool1d); StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -553,7 +553,7 @@ Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -652,7 +652,7 @@ Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool3d").set_body_typed(adaptive_avg_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool3d").set_body_typed(adaptive_avg_pool3d); StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 94fabd2d0ede..c581b20835c6 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -119,7 +119,7 @@ Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs return Call(op, call_args, attrs, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); // call_inplace_packed @@ -238,7 +238,7 @@ Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_i return Call(op, call_args, Attrs(attrs), sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); // call_tir @@ -600,7 +600,7 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); // call_tir_with_grad @@ -652,7 +652,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWithGrad); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWithGrad); // call_tir_inplace @@ -793,7 +793,7 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir_inplace").set_body_typed(MakeCallTIRInplace); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_inplace").set_body_typed(MakeCallTIRInplace); // call_dps_packed @@ -834,7 +834,7 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_ return Call(op, {func, args}, {}, {out_sinfo}); } -TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { @@ -860,7 +860,7 @@ Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) return Call(op, {func, args}, Attrs(), sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) @@ -872,7 +872,7 @@ Expr MakeCallNullValue() { return Call(op, {}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); +TVM_FFI_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); // print @@ -895,7 +895,7 @@ Expr MakePrint(Array vals, StringImm format) { return Call(op, params); } -TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); +TVM_FFI_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); // assert_op @@ -938,7 +938,7 @@ Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { return Call(op, args); } -TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); +TVM_FFI_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); // make_closure @@ -954,7 +954,7 @@ Expr MakeClosure(Expr func, Tuple args) { return Call(op, {func, args}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); // invoke_closure @@ -981,7 +981,7 @@ Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); // invoke_pure_closure @@ -997,7 +997,7 @@ Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); // shape_of @@ -1012,7 +1012,7 @@ Expr MakeShapeOf(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); +TVM_FFI_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); // tensor_to_shape @@ -1046,7 +1046,7 @@ Expr MakeTensorToShape(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); +TVM_FFI_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); // shape_to_tensor StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { @@ -1070,7 +1070,7 @@ Expr MakeShapeToTensor(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); // alloc_tensor @@ -1107,7 +1107,7 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); // memory planning alloc_storage @@ -1132,7 +1132,7 @@ Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm stora return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); // memory planning alloc_tensor @@ -1163,7 +1163,7 @@ Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); // memory planning kill_storage @@ -1179,7 +1179,7 @@ Expr MakeMemKillStorage(Expr storage) { return Call(op, {storage}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); // memory planning kill_tensor @@ -1195,7 +1195,7 @@ Expr MakeMemKillTensor(Expr tensor) { return Call(op, {tensor}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); // vm alloc_storage @@ -1219,7 +1219,7 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); // vm alloc_tensor @@ -1257,7 +1257,7 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm d return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); // vm kill_object @@ -1273,7 +1273,7 @@ Expr MakeVMKillObject(Expr obj) { return Call(op, {std::move(obj)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject); // vm call_tir_dyn @@ -1291,7 +1291,7 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { return Call(op, {func, args}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); // builtin stop_lift_params StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { @@ -1309,7 +1309,7 @@ Expr MakeStopLiftParams(Expr x) { return Call(op, {x}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); +TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); // to_vdevice TVM_REGISTER_NODE_TYPE(ToVDeviceAttrs); @@ -1340,7 +1340,7 @@ Expr MakeToVDevice(Expr data, VDevice dst_vdev) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice); // hint_on_device TVM_REGISTER_NODE_TYPE(HintOnDeviceAttrs); @@ -1367,7 +1367,7 @@ Expr MakeHintOnDevice(Expr data, Device device) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); +TVM_FFI_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index df29d9c88503..d7d50f8fa714 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -181,7 +181,7 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c static const Op& op = Op::Get("relax." OpRegName); \ return Call(op, {std::move(x)}, Attrs(), {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + TVM_FFI_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) /************ Utilities ************/ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 6b106f760d5f..ae36d45b3683 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -42,7 +42,7 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {x1, x2}, Attrs(), {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ .add_argument("x1", "Tensor", "The first input tensor.") \ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index c06b51ae2e66..b2355b1af7f0 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -54,7 +54,7 @@ Expr full(Variant> shape, Expr fill_value, Optionalargs.size() != 2) { @@ -96,7 +96,7 @@ Expr full_like(Expr x, Expr fill_value, Optional dtype) { return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -174,8 +174,8 @@ Expr ones_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); -TVM_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); +TVM_FFI_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); TVM_REGISTER_OP("relax.ones") .set_attrs_type() @@ -209,8 +209,8 @@ Expr zeros_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); -TVM_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); +TVM_FFI_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); TVM_REGISTER_OP("relax.zeros") .set_attrs_type() @@ -242,8 +242,8 @@ Expr eye_like(Expr x, PrimValue k, Optional dtype) { return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); -TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); +TVM_FFI_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -319,7 +319,7 @@ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); +TVM_FFI_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -380,8 +380,8 @@ Expr triu(Expr x, Expr k) { Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } -TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); -TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); +TVM_FFI_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); +TVM_FFI_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index bc24285cf9c7..d1d5bbccbcc7 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -40,7 +40,7 @@ Expr astype(Expr x, DataType dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); +TVM_FFI_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -70,7 +70,7 @@ Expr MakeWrapParam(Expr data, DataType dtype) { return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); +TVM_FFI_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index d8aecb3461d4..c25d587052f1 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -35,7 +35,7 @@ Expr no_grad(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad); StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -53,7 +53,7 @@ Expr start_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -75,7 +75,7 @@ Expr end_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -111,7 +111,7 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optiona } } -TVM_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward); StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -145,7 +145,7 @@ Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward); StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -177,7 +177,7 @@ Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward); StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -202,7 +202,7 @@ Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axi return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward); StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 8f262ce38da7..26978f2fad74 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -47,7 +47,7 @@ Expr take(Expr x, Expr indices, Optional axis) { return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); +TVM_FFI_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); @@ -169,7 +169,7 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid return call; } -TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); +TVM_FFI_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); /* \brief Helper function to unpack a relax::Tuple * @@ -477,7 +477,7 @@ Expr dynamic_strided_slice(Expr x, // return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); } -TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); +TVM_FFI_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 4ca42bffec90..28af375f00a6 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -44,7 +44,7 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype) { return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); +TVM_FFI_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -176,7 +176,7 @@ Expr einsum(Expr operands, String subscripts) { return Call(op, {std::move(operands)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.einsum").set_body_typed(einsum); +TVM_FFI_REGISTER_GLOBAL("relax.op.einsum").set_body_typed(einsum); StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -258,7 +258,7 @@ Expr outer(Expr x1, Expr x2) { return Call(op, {std::move(x1), std::move(x2)}, {}); } -TVM_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer); +TVM_FFI_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer); StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { auto input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 9b653046f76a..e98ba946c512 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -41,7 +41,7 @@ Expr broadcast_to(Expr x, Expr shape) { return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); +TVM_FFI_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -124,7 +124,7 @@ Expr concat(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); +TVM_FFI_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -340,7 +340,7 @@ Expr expand_dims(Expr x, Array axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); +TVM_FFI_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -446,7 +446,7 @@ Expr flatten(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); +TVM_FFI_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -481,7 +481,7 @@ Expr index_tensor(Expr first, Expr tensors) { return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -635,7 +635,7 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); +TVM_FFI_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -702,7 +702,7 @@ Expr permute_dims(Expr x, Optional> axes) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); +TVM_FFI_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); bool IsIdentityPermutation(const std::vector& permutation) { for (int i = 0; i < static_cast(permutation.size()); ++i) { @@ -910,7 +910,7 @@ Expr reshape(Expr x, Variant> shape) { return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); +TVM_FFI_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -997,7 +997,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); +TVM_FFI_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1150,7 +1150,7 @@ Expr squeeze(Expr x, Optional> axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); +TVM_FFI_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1350,7 +1350,7 @@ Expr stack(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack); +TVM_FFI_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack); Optional> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -1554,7 +1554,7 @@ Expr collapse_sum_like(Expr data, Expr collapse_target) { return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -1600,7 +1600,7 @@ Expr collapse_sum_to(Expr data, Expr shape) { return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); +TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -1655,7 +1655,7 @@ Expr repeat(Expr data, int repeats, Optional axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); +TVM_FFI_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1720,7 +1720,7 @@ Expr tile(Expr data, Array repeats) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); +TVM_FFI_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1783,7 +1783,7 @@ Expr flip(Expr data, Integer axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.flip").set_body_typed(flip); +TVM_FFI_REGISTER_GLOBAL("relax.op.flip").set_body_typed(flip); StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -1820,7 +1820,7 @@ Expr gather_elements(Expr data, Expr indices, int axis) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); +TVM_FFI_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1889,7 +1889,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); +TVM_FFI_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1983,7 +1983,7 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { return Call(op, {data, indices, values}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); +TVM_FFI_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -2106,7 +2106,7 @@ Expr meshgrid(Expr tensors, Optional indexing) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid); +TVM_FFI_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid); StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -2210,7 +2210,7 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements); +TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements); StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2324,7 +2324,7 @@ Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); +TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] @@ -2467,7 +2467,7 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } // namespace relax -TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); +TVM_FFI_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 0189ef96780d..78ba6fec34ac 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -44,7 +44,7 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dty return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); +TVM_FFI_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -128,7 +128,7 @@ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_d return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_REGISTER_GLOBAL("relax.op.dequantize").set_body_typed(dequantize); +TVM_FFI_REGISTER_GLOBAL("relax.op.dequantize").set_body_typed(dequantize); StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 35ee4c486b1d..80bbf48fd4f9 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -43,7 +43,8 @@ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indice Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.multinomial_from_uniform").set_body_typed(multinomial_from_uniform); +TVM_FFI_REGISTER_GLOBAL("relax.op.multinomial_from_uniform") + .set_body_typed(multinomial_from_uniform); StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 4df166215414..83e0e246b1bf 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -36,7 +36,7 @@ Expr where(Expr condition, Expr x1, Expr x2) { return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); +TVM_FFI_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -195,7 +195,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs(attrs)); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index f2f840ba3154..e321b326d24e 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -46,7 +46,7 @@ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_i return call; } -TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); +TVM_FFI_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); @@ -144,7 +144,7 @@ Expr nonzero(Expr x) { return Call(op, {std::move(x)}); } -TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); +TVM_FFI_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 9f8545e9b3a2..1cd061084e1a 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -41,7 +41,7 @@ Expr sort(Expr data, int axis, bool descending) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.sort").set_body_typed(sort); +TVM_FFI_REGISTER_GLOBAL("relax.op.sort").set_body_typed(sort); StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { return GetUnaryInputTensorStructInfo(call, ctx); @@ -67,7 +67,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.argsort").set_body_typed(argsort); +TVM_FFI_REGISTER_GLOBAL("relax.op.argsort").set_body_typed(argsort); StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -101,7 +101,7 @@ Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dt return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.topk").set_body_typed(topk); +TVM_FFI_REGISTER_GLOBAL("relax.op.topk").set_body_typed(topk); StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index e4765c8ddb3c..73e74578fc06 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -186,7 +186,7 @@ Expr cumprod(Expr data, Optional axis, Optional dtype, Bool e return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.cumprod").set_body_typed(cumprod); +TVM_FFI_REGISTER_GLOBAL("relax.op.cumprod").set_body_typed(cumprod); TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() @@ -206,7 +206,7 @@ Expr cumsum(Expr data, Optional axis, Optional dtype, Bool ex return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); +TVM_FFI_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index aa312f5df766..331562454efe 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -50,7 +50,7 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index cc265ad9a160..91a6e8d0ae04 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -143,7 +143,7 @@ Expr ewise_fma(Expr x1, Expr x2, Expr x3) { return Call(op, {x1, x2, x3}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); +TVM_FFI_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index f95eb721fc70..828a91dde21d 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -85,7 +85,7 @@ Expr clip(Expr x, Expr min, Expr max) { return Call(op, {std::move(x), std::move(min), std::move(max)}); } -TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); +TVM_FFI_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); /***************** Check operators *****************/ diff --git a/src/relax/testing/transform.cc b/src/relax/testing/transform.cc index eed2329e3d3a..c4e41d5afc1f 100644 --- a/src/relax/testing/transform.cc +++ b/src/relax/testing/transform.cc @@ -35,7 +35,7 @@ tvm::transform::Pass ApplyEmptyCppMutator() { "relax.testing.ApplyEmptyCppMutator", {}); } -TVM_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") .set_body_typed(ApplyEmptyCppMutator); } // namespace testing diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 43fa2357f09d..cb44339f1969 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -215,7 +215,7 @@ Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outpu /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); +TVM_FFI_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); } // namespace transform diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 8d268cfef7f3..46dc803018ea 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -213,7 +213,7 @@ Pass AdjustMatmulOrder() { return CreateFunctionPass(pass_func, 1, "AdjustMatmulOrder", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AdjustMatmulOrder").set_body_typed(AdjustMatmulOrder); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AdjustMatmulOrder").set_body_typed(AdjustMatmulOrder); } // namespace transform } // namespace relax diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index c8a2cef400b6..763a009a24b2 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -201,7 +201,7 @@ Pass AllocateWorkspace() { return CreateModulePass(pass_func, 0, "AllocateWorkspace", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 61d7725e6c13..63521e4e8fe1 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -438,7 +438,7 @@ Pass AlterOpImpl(const Map& op_impl_map, /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); } // namespace transform } // namespace relax diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc index 5127b3df0c26..e2b0fc2c2877 100644 --- a/src/relax/transform/annotate_tir_op_pattern.cc +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -47,7 +47,8 @@ Pass AnnotateTIROpPattern() { return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern") + .set_body_typed(AnnotateTIROpPattern); } // namespace transform diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index 3593b22c10ab..cef74890806d 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -105,7 +105,7 @@ Pass AttachAttrLayoutFreeBuffers() { return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); } -TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") +TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") .set_body_typed(AttachAttrLayoutFreeBuffers); } // namespace transform } // namespace relax diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 97226491a809..905d2bcd838d 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -79,7 +79,7 @@ Pass AttachGlobalSymbol() { return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 6a871edeaa9d..2a5c6f525d50 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -196,7 +196,7 @@ IRModule BindParam(IRModule m, String func_name, Map bind_ return GetRef(new_module); } -TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); namespace transform { @@ -207,7 +207,7 @@ Pass BindParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); } // namespace transform diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 2df9ed1f01a3..49af21c10755 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -148,7 +148,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m } } // namespace -TVM_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars); namespace transform { @@ -170,7 +170,7 @@ Pass BindSymbolicVars(Map binding_map, Optional fun return tvm::transform::CreateModulePass(pass_func, 1, "relax.BindSymbolicVars", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars); } // namespace transform } // namespace relax diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index a011841c1316..982e1ac0c323 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -115,7 +115,7 @@ Pass BundleModelParams(Optional param_tuple_name) { return CreateModulePass(pass_func, 1, "BundleModelParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams); } // namespace transform } // namespace relax diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index e3bb2bcbae46..25b4abadc7ff 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -183,7 +183,7 @@ Pass CallTIRRewrite() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); } // namespace transform diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 3d8c65de95a6..ecbb9e77518e 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -591,7 +591,8 @@ Pass CanonicalizeBindings() { "CanonicalizeBindings"); } -TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings") + .set_body_typed(CanonicalizeBindings); } // namespace transform diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 916061aa575f..620186320342 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -387,7 +387,8 @@ Pass CombineParallelMatmul(FCheck check) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul").set_body_typed(CombineParallelMatmul); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul") + .set_body_typed(CombineParallelMatmul); } // namespace transform diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 25a1d1b0ede6..e6db2eb73f3a 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -86,7 +86,7 @@ Pass ComputePrimValue() { return CreateModulePass(pass_func, 0, "ComputePrimValue", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); } // namespace transform diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index 1c4cef892af5..c359afdebc28 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -159,7 +159,7 @@ Pass ConvertToDataflow(int min_size) { return tvm::transform::Sequential({pass, CanonicalizeBindings()}); } -TVM_REGISTER_GLOBAL("relax.transform.ConvertToDataflow").set_body_typed(ConvertToDataflow); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertToDataflow").set_body_typed(ConvertToDataflow); } // namespace transform diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 78670bd51af0..0c06cac75d19 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -350,7 +350,7 @@ Pass ConvertLayout(Map> desired_layouts) { return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); } // namespace transform } // namespace relax diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index aee2c015fc81..51ab6bb23068 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -1016,13 +1016,13 @@ Array> DataflowInplaceAnalysis(const DataflowBlock& bl } // these are exposed only for testing -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") .set_body_typed(DataflowAliasAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") .set_body_typed(DataflowInplaceAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") .set_body_typed([](const IRModule& mod, const Call& call, const Array& inplace_indices) -> Array { ModuleInplaceTransformer transformer(mod); @@ -1031,7 +1031,7 @@ TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") }); // actually exposed -TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") .set_body_typed(DataflowUseInplaceCalls); } // namespace transform diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index abceb517c1dd..7de1da329f88 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -140,7 +140,7 @@ Pass DeadCodeElimination(Array entry_functions) { return CreateModulePass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); +TVM_FFI_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform } // namespace relax diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 1a4cd216256b..eec27f3b7888 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -250,10 +250,10 @@ Pass DecomposeOpsForTraining(Optional func_name) { } } -TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference") .set_body_typed(DecomposeOpsForInference); -TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForTraining") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForTraining") .set_body_typed(DecomposeOpsForTraining); } // namespace transform diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 5173df7ef386..8a5ce1db04de 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -221,7 +221,7 @@ Pass EliminateCommonSubexpr(bool call_only) { return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {}); } -TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") +TVM_FFI_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 134eca557264..d7bf2dd95ffb 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -104,7 +104,7 @@ Pass ExpandMatmulOfSum() { return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); } // namespace transform } // namespace relax diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 8d5c833d43fb..1a9afadf7e48 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -178,7 +178,8 @@ Pass ExpandTupleArguments() { "ExpandTupleArguments"); } -TVM_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments").set_body_typed(ExpandTupleArguments); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments") + .set_body_typed(ExpandTupleArguments); } // namespace transform diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 084791bde123..4ccf6c25abc8 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -172,7 +172,7 @@ Pass FewShotTuning(int valid_count, bool benchmark) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FewShotTuning").set_body_typed(FewShotTuning); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FewShotTuning").set_body_typed(FewShotTuning); } // namespace transform } // namespace relax diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 2d916d0391ea..c2bac5daa7f2 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -326,7 +326,7 @@ Pass FoldConstant() { return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } -TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 74141724ee24..f9ffcd930283 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1395,7 +1395,7 @@ FusionPattern::FusionPattern(String name, DFPattern pattern, } TVM_REGISTER_NODE_TYPE(FusionPatternNode); -TVM_REGISTER_GLOBAL("relax.transform.FusionPattern") +TVM_FFI_REGISTER_GLOBAL("relax.transform.FusionPattern") .set_body_typed([](String name, DFPattern pattern, Map annotation_patterns, Optional check, Optional attrs_getter) { return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); @@ -1429,7 +1429,7 @@ Pass FuseOps(int fuse_opt_level) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, bool annotate_codegen, const Array& entry_function_names) { @@ -1444,7 +1444,7 @@ Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_const /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); } // namespace transform diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 0dc807a5a624..05b7bf4218dd 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -1262,7 +1262,7 @@ Pass FuseTIR() { "FuseTIR"); } -TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); } // namespace transform diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 35c6d9af5ec0..9998b6da93f3 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -787,7 +787,7 @@ Pass Gradient(String func_name, Optional> require_grads, int target_i /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient); +TVM_FFI_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient); } // namespace transform diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 981a5e654c1c..26b106373ff0 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -164,7 +164,7 @@ Function FunctionInlineFunctions(Function func, return Downcast(mutator(std::move(func))); } -TVM_REGISTER_GLOBAL("relax.FunctionInlineFunctions").set_body_typed(FunctionInlineFunctions); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionInlineFunctions").set_body_typed(FunctionInlineFunctions); namespace transform { @@ -219,7 +219,7 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "InlinePrivateFunctions", {}); } -TVM_REGISTER_GLOBAL("relax.transform.InlinePrivateFunctions") +TVM_FFI_REGISTER_GLOBAL("relax.transform.InlinePrivateFunctions") .set_body_typed(InlinePrivateFunctions); } // namespace transform diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 20ec5eb4348f..730f65f701ba 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -265,7 +265,7 @@ Pass KillAfterLastUse() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {}); } -TVM_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); +TVM_FFI_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 41a970abfecf..e5e28cb55375 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -503,7 +503,7 @@ Pass LambdaLift() { return tvm::transform::CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index d1cf4ff7147b..32f63e1e141b 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -259,7 +259,7 @@ Pass LazyGetInput() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput); Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { @@ -274,7 +274,7 @@ Pass LazySetOutput() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput); } // namespace transform } // namespace relax diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 71a55a3d84d1..a0ac6fffb62c 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -404,7 +404,7 @@ Pass LegalizeOps(Optional> cmap, bool enable_warning) /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); } // namespace transform diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index df7242efd462..9013737df5e4 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -867,7 +867,7 @@ Pass LiftTransformParams(Variant> shared_transform) { "LiftTransformParams"); } -TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 13705c0908cc..3bdbfd0b94a9 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -99,7 +99,7 @@ Pass LowerAllocTensor() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor); } // namespace transform } // namespace relax diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 9538249b771b..ffeddd08c401 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -421,7 +421,7 @@ Pass MergeCompositeFunctions() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") .set_body_typed(MergeCompositeFunctions); } // namespace transform diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 57012a40ddaa..cf7b9fc03a50 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -201,10 +201,11 @@ Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { /*traceable*/ true); } -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") .set_body_typed(MetaScheduleApplyDatabase); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod") + .set_body_typed(MetaScheduleTuneIRMod); +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 2c57c89a265c..07ca6a1133e7 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -279,7 +279,7 @@ Pass Normalize() { return CreateFunctionPass(pass_func, 1, "Normalize", {}); } -TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); +TVM_FFI_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); Pass NormalizeGlobalVar() { auto pass_func = [=](IRModule mod, PassContext pc) { @@ -290,7 +290,7 @@ Pass NormalizeGlobalVar() { /*pass_name=*/"NormalizeGlobalVar", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.NormalizeGlobalVar").set_body_typed(NormalizeGlobalVar); +TVM_FFI_REGISTER_GLOBAL("relax.transform.NormalizeGlobalVar").set_body_typed(NormalizeGlobalVar); } // namespace transform diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 15474eb8f8d6..ee4773fb3a24 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -415,7 +415,7 @@ Pass RealizeVDevice() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.RealizeVDevice").set_body_typed(RealizeVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RealizeVDevice").set_body_typed(RealizeVDevice); } // namespace transform } // namespace relax diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index a88f5e0f5629..31e771d2adec 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -88,7 +88,8 @@ Pass RemovePurityChecking() { return CreateFunctionPass(pass_func, 0, "RemovePurityChecking", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RemovePurityChecking").set_body_typed(RemovePurityChecking); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemovePurityChecking") + .set_body_typed(RemovePurityChecking); } // namespace transform diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index b967ad68d126..e170588f60c6 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -336,7 +336,7 @@ Pass RemoveUnusedOutputs() { "RemoveUnusedOutputs"); } -TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs); } // namespace transform diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index bc7fa325ccc7..911e427935be 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -250,7 +250,7 @@ Pass RemoveUnusedParameters() { return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") .set_body_typed(RemoveUnusedParameters); } // namespace transform diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index a2023a068aa2..2016c6766c08 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -173,7 +173,7 @@ Pass ReorderPermuteDimsAfterConcat() { return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") +TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") .set_body_typed(ReorderPermuteDimsAfterConcat); } // namespace transform diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 28480a2296f3..4c87cbe8b7e3 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -156,7 +156,7 @@ Pass ReorderTakeAfterMatmul() { return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul") +TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul") .set_body_typed(ReorderTakeAfterMatmul); } // namespace transform diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 608f72ee1f20..14e98ecad152 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -897,7 +897,7 @@ Pass RewriteCUDAGraph() { return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph); } // namespace transform diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 690f1b723279..a13c23387821 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -165,7 +165,7 @@ Pass RewriteDataflowReshape() { return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") +TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") .set_body_typed(RewriteDataflowReshape); } // namespace transform diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 3940ecd70bd5..7556f2690307 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -219,7 +219,7 @@ Pass RunCodegen(Optional>> target_options, return CreateModulePass(pass_func, 0, "RunCodegen", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index d0edab850652..276ba448cc4b 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -774,7 +774,8 @@ Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { /*pass_name=*/"SplitCallTIRByPattern", // /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern") + .set_body_typed(SplitCallTIRByPattern); } // namespace transform diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 542dba3cf6c6..7990beb04b2e 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -340,7 +340,7 @@ Pass SplitLayoutRewritePreproc() { return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, "SplitLayoutRewritePreproc"); } -TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") +TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") .set_body_typed(SplitLayoutRewritePreproc); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index f3a9860eaa7f..0a51e9cd4acb 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -983,7 +983,8 @@ Pass StaticPlanBlockMemory() { return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); } -TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory); +TVM_FFI_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory") + .set_body_typed(StaticPlanBlockMemory); } // namespace transform } // namespace relax diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index c82d4b573646..531ecefd5d66 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -618,7 +618,7 @@ Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_in return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); } // namespace transform diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index b18ece65a6db..ef1616c83ed8 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -61,7 +61,7 @@ Pass ToNonDataflow() { return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); } // namespace transform diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index a42b1db7712b..1ba78cdc5e2c 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -342,7 +342,7 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {}); } -TVM_REGISTER_GLOBAL("relax.transform.TopologicalSort") +TVM_FFI_REGISTER_GLOBAL("relax.transform.TopologicalSort") .set_body_typed([](String order_str, String direction_str) -> Pass { TraversalOrder order = [&]() { if (order_str == "depth-first") { diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index c0336f9df335..fedc61019b06 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -311,32 +311,34 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TuningRecordNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") .set_body_typed([](Trace trace, Optional> run_secs) { return TuningRecord(trace, run_secs); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON") + .set_body_typed(TuningRecord::FromJSON); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") .set_body_method(&DatabaseNode::HasWorkload); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") .set_body_method(&DatabaseNode::HasMeasurementRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") .set_body_method(&DatabaseNode::HasTuningRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") .set_body_method(&DatabaseNode::CommitMeasurementRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") .set_body_method(&DatabaseNode::GetMeasurementRecord); TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase") + .set_body_typed(Database::JSONDatabase); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc index 949cf611c20f..5f53b5166725 100644 --- a/src/relax/transform/tuning_api/primitives.cc +++ b/src/relax/transform/tuning_api/primitives.cc @@ -231,41 +231,42 @@ Trace Trace::FromJSON(const ObjectRef& json) { /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ChoiceNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Choice") .set_body_typed([](String transform_func_key, Array transform_func_args, String constr_func_key, Array constr_func_args) { return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") .set_body_method(&ChoiceNode::GetTransformFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") .set_body_method(&ChoiceNode::GetConstrFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") .set_body_method(&ChoiceNode::ApplyTransformFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr").set_body_method(&ChoiceNode::CheckConstr); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); TVM_REGISTER_NODE_TYPE(KnobNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Knob") .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") .set_body_method(&KnobNode::IsValidDecision); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Trace") .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { return Trace(in_mod, knobs, decisions); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod").set_body_method(&TraceNode::SetOutMod); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod").set_body_method(&TraceNode::SetOutMod); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 062ac97a35f7..472f454bc11a 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -104,7 +104,8 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_f return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {}); } -TVM_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo").set_body_typed(UpdateParamStructInfo); +TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo") + .set_body_typed(UpdateParamStructInfo); } // namespace transform } // namespace relax diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index 5a8346578e7c..d2a1f85be853 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -106,7 +106,7 @@ Pass UpdateVDevice(VDevice new_vdevice, int64_t index) { /*pass_name=*/"UpdateVDevice", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice); } // namespace transform } // namespace relax diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 96fd5578e40a..ab270c08a65d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -245,7 +245,7 @@ Expr GetBoundValue(const Binding& b) { */ Function CopyWithNewVars(Function func) { return FunctionCopier().Copy(func); } -TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); +TVM_FFI_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); } // namespace relax } // namespace tvm diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index ba3dddb7ae49..7f7d416f88d9 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -22,22 +22,22 @@ * \brief Functions for conversion between fp32 and fp16 */ #include -#include +#include extern "C" { // disable under msvc #ifndef _MSC_VER -TVM_DLL TVM_WEAK uint16_t __gnu_f2h_ieee(float a) { +TVM_DLL TVM_FFI_WEAK uint16_t __gnu_f2h_ieee(float a) { return __truncXfYf2__(a); } -TVM_DLL TVM_WEAK float __gnu_h2f_ieee(uint16_t a) { +TVM_DLL TVM_FFI_WEAK float __gnu_h2f_ieee(uint16_t a) { return __extendXfYf2__(a); } -TVM_DLL TVM_WEAK uint16_t __truncdfhf2(double a) { +TVM_DLL TVM_FFI_WEAK uint16_t __truncdfhf2(double a) { return __truncXfYf2__(a); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc deleted file mode 100644 index 751617914e66..000000000000 --- a/src/runtime/c_runtime_api.cc +++ /dev/null @@ -1,807 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file c_runtime_api.cc - * \brief Device specific implementations - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "object_internal.h" -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -std::string GetCustomTypeName(uint8_t type_code) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_name"); - return f(type_code).cast(); -} - -uint8_t GetCustomTypeCode(const std::string& type_name) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_code"); - return f(type_name).cast(); -} - -bool GetCustomTypeRegistered(uint8_t type_code) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_registered"); - return f(type_code).cast(); -} - -uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { - ICHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string"; - - auto tmp = s.c_str(); - - ICHECK(s.c_str() == tmp); - *scan = s.c_str() + 6; - ICHECK(s.c_str() == tmp); - if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; - ICHECK(s.c_str() == tmp); - *scan += 1; - ICHECK(s.c_str() == tmp); - size_t custom_name_len = 0; - ICHECK(s.c_str() == tmp); - while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') - ++custom_name_len; - ICHECK(s.c_str() == tmp); - if (*(*scan + custom_name_len) != ']') - LOG(FATAL) << "expected closing brace after 'custom' type in" << s; - ICHECK(s.c_str() == tmp); - *scan += custom_name_len + 1; - ICHECK(s.c_str() == tmp); - - auto type_name = s.substr(7, custom_name_len); - ICHECK(s.c_str() == tmp); - return GetCustomTypeCode(type_name); -} - -class DeviceAPIManager { - public: - static const int kMaxDeviceAPI = TVMDeviceExtType_End; - // Get API - static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } - static DeviceAPI* Get(int dev_type, bool allow_missing = false) { - return Global()->GetAPI(dev_type, allow_missing); - } - - private: - std::array api_; - DeviceAPI* rpc_api_{nullptr}; - std::mutex mutex_; - // constructor - DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } - // Global static variable. - static DeviceAPIManager* Global() { - static DeviceAPIManager* inst = new DeviceAPIManager(); - return inst; - } - // Get or initialize API. - DeviceAPI* GetAPI(int type, bool allow_missing) { - if (type < kRPCSessMask) { - if (api_[type] != nullptr) return api_[type]; - std::lock_guard lock(mutex_); - if (api_[type] != nullptr) return api_[type]; - api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); - return api_[type]; - } else { - if (rpc_api_ != nullptr) return rpc_api_; - std::lock_guard lock(mutex_); - if (rpc_api_ != nullptr) return rpc_api_; - rpc_api_ = GetAPI("rpc", allow_missing); - return rpc_api_; - } - } - DeviceAPI* GetAPI(const std::string name, bool allow_missing) { - std::string factory = "device_api." + name; - const auto f = tvm::ffi::Function::GetGlobal(factory); - if (!f.has_value()) { - ICHECK(allow_missing) << "Device API " << name << " is not enabled."; - return nullptr; - } - void* ptr = (*f)().cast(); - return static_cast(ptr); - } -}; - -DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { - return DeviceAPIManager::Get(static_cast(dev.device_type), allow_missing); -} - -void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); -} - -static size_t GetDataAlignment(const DLDataType dtype) { - size_t align = (dtype.bits / 8) * dtype.lanes; - if (align < kAllocAlignment) return kAllocAlignment; - return align; -} - -size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { - size_t size = 1; - for (tvm_index_t i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; - } - LOG(FATAL) << "Device does not support physical mem computation with " - << "specified memory scope: " << mem_scope.value(); - return 0; -} - -void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") { - // by default, we can always redirect to the flat memory allocations - DLTensor temp; - temp.data = nullptr; - temp.device = dev; - temp.ndim = ndim; - temp.dtype = dtype; - temp.shape = const_cast(shape); - temp.strides = nullptr; - temp.byte_offset = 0; - size_t size = GetDataSize(temp); - size_t alignment = GetDataAlignment(temp.dtype); - return AllocDataSpace(dev, size, alignment, dtype); - } - LOG(FATAL) << "Device does not support allocate data space with " - << "specified memory scope: " << mem_scope.value(); - return nullptr; -} - -void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - // by default, we can always redirect to the flat memory copy operation. - size_t nbytes = GetDataSize(*from); - ICHECK_EQ(nbytes, GetDataSize(*to)); - - ICHECK(IsContiguous(*from) && IsContiguous(*to)) - << "CopyDataFromTo only support contiguous array for now"; - CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, - to->device, from->dtype, stream); -} - -void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, Device dev_from, Device dev_to, - DLDataType type_hint, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support CopyDataFromTo."; -} - -void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } - -TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } - -void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} - -TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } - -void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { -} - -//-------------------------------------------------------- -// Error handling mechanism -// ------------------------------------------------------- -// Standard error message format, {} means optional -//-------------------------------------------------------- -// {error_type:} {message0} -// {message1} -// {message2} -// {Stack trace:} // stack traces follow by this line -// {trace 0} // two spaces in the beginning. -// {trace 1} -// {trace 2} -//-------------------------------------------------------- -/*! - * \brief Normalize error message - * - * Parse them header generated by LOG(FATAL) and ICHECK - * and reformat the message into the standard format. - * - * This function will also merge all the stack traces into - * one trace and trim them. - * - * \param err_msg The error message. - * \return normalized message. - */ -std::string NormalizeError(std::string err_msg) { - // ------------------------------------------------------------------------ - // log with header, {} indicates optional - //------------------------------------------------------------------------- - // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0} - // {message1} - // Stack trace: - // {stack trace 0} - // {stack trace 1} - //------------------------------------------------------------------------- - // Normalzied version - //------------------------------------------------------------------------- - // error_type: check_msg message0 - // {message1} - // Stack trace: - // File file_name, line lineno - // {stack trace 0} - // {stack trace 1} - //------------------------------------------------------------------------- - // LEGACY-COMPACT: - // skip python-style error style - // TODO(tqchen) move to new FFI handling - if (err_msg.find("Traceback (most recent call last)") != std::string::npos) { - return err_msg; - } - int line_number = 0; - std::istringstream is(err_msg); - std::string line, file_name, error_type, check_msg; - - // Parse log header and set the fields, - // Return true if it the log is in correct format, - // return false if something is wrong. - auto parse_log_header = [&]() { - // skip timestamp - if (is.peek() != '[') { - getline(is, line); - return true; - } - if (!(is >> line)) return false; - // get filename - while (is.peek() == ' ') is.get(); -#ifdef _MSC_VER // handle volume separator ":" in Windows path - std::string drive; - if (!getline(is, drive, ':')) return false; - if (!getline(is, file_name, ':')) return false; - file_name = drive + ":" + file_name; -#else - if (!getline(is, file_name, ':')) return false; -#endif - // get line number - if (!(is >> line_number)) return false; - // get rest of the message. - while (is.peek() == ' ' || is.peek() == ':') is.get(); - if (!getline(is, line)) return false; - // detect check message, rewrite to remote extra : - if (line.compare(0, 13, "Check failed:") == 0) { - std::string ending = ": "; - size_t end_pos = line.find(ending, 13); - if (end_pos == std::string::npos) return false; - check_msg = line.substr(0, end_pos + ending.size()); - line = line.substr(end_pos + ending.size()); - } - return true; - }; - // if not in correct format, do not do any rewrite. - if (!parse_log_header()) return err_msg; - // Parse error type. - { - size_t start_pos = 0, end_pos; - for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) { - } - for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { - char ch = line[end_pos]; - if (ch == ':') { - error_type = line.substr(start_pos, end_pos - start_pos); - break; - } - // [A-Z0-9a-z_.] - if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break; - } - if (error_type.length() != 0) { - // if we successfully detected error_type: trim the following space. - for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' '; - ++start_pos) { - } - line = line.substr(start_pos); - } else { - // did not detect error_type, use default value. - line = line.substr(start_pos); - error_type = "TVMError"; - } - } - // Separate out stack trace. - std::ostringstream os; - os << error_type << ": " << check_msg << line << '\n'; - - bool trace_mode = true; - std::vector stack_trace; - while (getline(is, line)) { - if (trace_mode) { - if (line.compare(0, 2, " ") == 0) { - stack_trace.push_back(line); - } else { - trace_mode = false; - // remove EOL trailing stacktrace. - if (line.length() == 0) continue; - } - } - if (!trace_mode) { - if (line.compare(0, 11, "Stack trace") == 0) { - trace_mode = true; - } else { - os << line << '\n'; - } - } - } - if (stack_trace.size() != 0 || file_name.length() != 0) { - os << "Stack trace:\n"; - if (file_name.length() != 0) { - os << " File \"" << file_name << "\", line " << line_number << "\n"; - } - // Print out stack traces, optionally trim the c++ traces - // about the frontends (as they will be provided by the frontends). - bool ffi_boundary = false; - for (const auto& line : stack_trace) { - // Heuristic to detect python ffi. - if (line.find("libffi.so") != std::string::npos || - line.find("core.cpython") != std::string::npos) { - ffi_boundary = true; - } - // If the backtrace is not c++ backtrace with the prefix " [bt]", - // then we can stop trimming. - if (ffi_boundary && line.compare(0, 6, " [bt]") != 0) { - ffi_boundary = false; - } - if (!ffi_boundary) { - os << line << '\n'; - } - // The line after TVMFuncCall cound be in FFI. - if (line.find("(TVMFuncCall") != std::string::npos) { - ffi_boundary = true; - } - } - } - return os.str(); -} - -} // namespace runtime -} // namespace tvm - -using namespace tvm::runtime; - -struct WrappedPythonError : Error { - WrappedPythonError() : Error("WrappedPythonError", "", TVM_FFI_TRACEBACK_HERE) {} - explicit WrappedPythonError(WrappedPythonObject obj) - : Error("WrappedPythonError", "", TVM_FFI_TRACEBACK_HERE), obj(std::move(obj)) {} - - WrappedPythonObject obj; -}; - -struct TVMRuntimeEntry { - std::string ret_str; - TVMByteArray ret_bytes; - - std::variant last_error; - std::string last_error_formatted; -}; - -typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; - -const char* TVMGetLastError() { - auto* store = TVMAPIRuntimeStore::Get(); - const auto& last_error = store->last_error; - if (const auto* message = std::get_if(&last_error)) { - return message->c_str(); - } else if (const auto* internal = std::get_if(&last_error)) { - // Use last_error_formatted to store the formatted error message, to avoid - // dangling pointer. - store->last_error_formatted = internal->what(); - return store->last_error_formatted.c_str(); - } else { - return nullptr; - } -} - -void* TVMGetLastPythonError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (auto* wrapped = std::get_if(&last_error)) { - return wrapped->obj.raw_pointer(); - } else { - return nullptr; - } -} - -const char* TVMGetLastBacktrace() { - const auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - static thread_local std::string traceback; - if (const auto* wrapped = std::get_if(&last_error)) { - traceback = wrapped->traceback(); - return traceback.c_str(); - } else if (const auto* wrapped = std::get_if(&last_error)) { - traceback = wrapped->traceback(); - return traceback.c_str(); - } else { - return nullptr; - } -} - -void TVMDropLastPythonError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (std::get_if(&last_error)) { - last_error = ""; - } -} - -int TVMAPIHandleException(const std::exception& e) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - - if (const auto* wrapped = dynamic_cast(&e)) { - last_error = *wrapped; - } else if (const auto* internal = dynamic_cast(&e)) { - last_error = *internal; - } else { - last_error = NormalizeError(e.what()); - } - return -1; -} - -void TVMAPISetLastPythonError(void* obj) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - last_error = WrappedPythonError(WrappedPythonObject(obj)); -} - -void TVMThrowLastError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (auto* wrapped = std::get_if(&last_error)) { - WrappedPythonError wrapped_err; - std::swap(wrapped_err, *wrapped); - throw wrapped_err; - } else if (auto* internal = std::get_if(&last_error)) { - throw *internal; - } else { - // redirect to tvm-ffi error handling. - throw ::tvm::ffi::details::MoveFromSafeCallRaised(); - } -} - -void TVMAPISetLastError(const char* msg) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - last_error = msg; -} - -int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { - API_BEGIN(); - tvm::ffi::Any ret; - ret = Module::LoadFromFile(file_name, format); - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(ret)); - *out = val.v_obj; - API_END(); -} - -int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { - API_BEGIN(); - ObjectInternal::GetModuleNode(mod)->Import(GetRef(ObjectInternal::GetModuleNode(dep))); - API_END(); -} - -int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, - TVMFunctionHandle* func) { - API_BEGIN(); - tvm::ffi::Function pf = - ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); - if (pf != nullptr) { - tvm::ffi::Any ret = pf; - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(ret)); - *func = val.v_obj; - } else { - *func = nullptr; - } - API_END(); -} - -int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } - -int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { - API_BEGIN(); - *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name))->get(); - API_END(); -} - -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, - int dtype_bits_hint) { - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - - DLDataType type_hint; - type_hint.code = static_cast(dtype_code_hint); - type_hint.bits = static_cast(dtype_bits_hint); - type_hint.lanes = 1; - - return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast(size), type_hint); -} - -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); - return 0; -} - -int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { - if (*handle == nullptr) { - *handle = reinterpret_cast(1); - return (*f)(cdata); - } - return 0; -} - -int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } - -int TVMByteArrayFree(TVMByteArray* arr) { - if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) { - return 0; // Thread-local storage does not need explicit deleting. - } - - delete arr; - return 0; -} - -int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, - TVMValue* ret_val, int* ret_type_code) { - API_BEGIN(); - tvm::ffi::Any rv; - tvm::ffi::FunctionObj* ffi_func = static_cast(func); - std::vector args_vec(num_args); - tvm::runtime::LegacyTVMArgsToPackedArgs(args, arg_type_codes, num_args, args_vec.data()); - ffi_func->CallPacked(args_vec.data(), args_vec.size(), &rv); - // special handle of certain return types. - if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIDataType || - rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIBytes || - rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIStr) { - // TODO(tvm-team): handle bytes return type here - TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); - if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIDataType) { - e->ret_str = DLDataTypeToString(rv.cast()); - *ret_type_code = kTVMStr; - ret_val->v_str = e->ret_str.c_str(); - } else if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIBytes) { - e->ret_str = rv.cast(); - e->ret_bytes.data = e->ret_str.c_str(); - e->ret_bytes.size = e->ret_str.length(); - *ret_type_code = kTVMBytes; - ret_val->v_handle = &(e->ret_bytes); - } else if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIStr) { - e->ret_str = rv.cast(); - *ret_type_code = kTVMStr; - ret_val->v_str = e->ret_str.c_str(); - } - } else { - MoveAnyToLegacyTVMValue(std::move(rv), ret_val, ret_type_code); - } - API_END(); -} - -int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { - API_BEGIN(); - ICHECK_EQ(num_ret, 1); - tvm::ffi::Any* rv = static_cast(ret); - *rv = LegacyTVMArgValueToAnyView(value[0], type_code[0]); - API_END(); -} - -int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, - TVMFunctionHandle* out) { - API_BEGIN(); - if (fin == nullptr) { - tvm::ffi::Any ret; - ret = tvm::ffi::Function::FromPacked( - [func, resource_handle](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, resource_handle); - if (ret != 0) { - TVMThrowLastError(); - } - }); - TVMValue val; - int type_code; - MoveAnyToLegacyTVMValue(std::move(ret), &val, &type_code); - *out = val.v_handle; - } else { - // wrap it in a shared_ptr, with fin as deleter. - // so fin will be called when the lambda went out of scope. - std::shared_ptr rpack(resource_handle, fin); - tvm::ffi::Any ret; - ret = - tvm::ffi::Function::FromPacked([func, rpack](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, rpack.get()); - - if (ret != 0) { - TVMThrowLastError(); - } - }); - TVMValue val; - val.v_handle = nullptr; - int type_code; - MoveAnyToLegacyTVMValue(std::move(ret), &val, &type_code); - *out = val.v_handle; - } - API_END(); -} - -int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - *out = DeviceAPIManager::Get(dev)->CreateStream(dev); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamCreate").set_body_typed([](DLDevice dev) { - return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); -}); - -int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->FreeStream(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamFree").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); -}); - -int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->SetStream(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_SetStream").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); -}); - -int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->StreamSync(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamSync").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); -}); - -int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, - TVMStreamHandle dst) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, src, dst); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamSyncFromTo") - .set_body_typed([](DLDevice dev, int64_t src, int64_t dst) { - DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), - reinterpret_cast(dst)); - }); - -int TVMCbArgToReturn(TVMValue* value, int* code) { - API_BEGIN(); - AnyView arg = LegacyTVMArgValueToAnyView(*value, *code); - Any rv; - if (auto opt_rv = arg.try_cast>()) { - rv = *std::move(*std::move(opt_rv)); - } else { - rv = arg; - } - MoveAnyToLegacyTVMValue(std::move(rv), value, code); - API_END(); -} - -int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint, - void** out_data) { - API_BEGIN(); - out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint); - API_END(); -} - -int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, DLDataType dtype, - const char* mem_scope, void** out_data) { - API_BEGIN(); - tvm::Optional scope; - if (mem_scope != nullptr) { - scope = tvm::String(std::string(mem_scope)); - } - out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, ndim, shape, dtype, scope); - API_END(); -} - -int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) { - API_BEGIN(); - DeviceAPIManager::Get(dev)->FreeDataSpace(dev, ptr); - API_END(); -} - -int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev_from = from->device; - DLDevice dev_to = to->device; - DLDevice dev = dev_from.device_type != kDLCPU ? dev_from : dev_to; - DeviceAPIManager::Get(dev)->CopyDataFromTo(from, to, stream); - API_END(); -} - -// set device api -TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - DeviceAPIManager::Get(dev)->SetDevice(dev); - }); - -// set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - - DeviceAttrKind kind = static_cast(args[2].cast()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); - if (api != nullptr) { - api->GetAttr(dev, kind, ret); - } else { - *ret = 0; - } - } else { - DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); - } - }); - -TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 75ce6612f43c..50b504d17c46 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -29,10 +29,10 @@ */ #include #include +#include #include #include #include -#include #include @@ -143,7 +143,7 @@ class ConstLoaderModuleNode : public ModuleNode { // Initialize the module with constants. int ret = init(md).cast(); // Report the error if initialization is failed. - ICHECK_EQ(ret, 0) << TVMGetLastError(); + ICHECK_EQ(ret, 0); break; } } @@ -247,7 +247,7 @@ Module ConstLoaderModuleCreate( return Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); } // namespace runtime diff --git a/src/runtime/container.cc b/src/runtime/container.cc deleted file mode 100644 index a789cfc769f2..000000000000 --- a/src/runtime/container.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/container.cc - * \brief Implementations of common containers. - */ -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace runtime { - -// Array -TVM_REGISTER_OBJECT_TYPE(ArrayObj); - -TVM_REGISTER_GLOBAL("runtime.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) { - Array result; - for (int i = 0; i < args.size(); ++i) { - result.push_back(args[i]); - } - *ret = result; -}); - -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem") - .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }); - -TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t { - return static_cast(n->size()); -}); - -// String -TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { - return String(std::move(str)); -}); - -TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { - return std::string(str); -}); - -// Map -TVM_REGISTER_GLOBAL("runtime.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) { - ICHECK_EQ(args.size() % 2, 0); - Map data; - for (int i = 0; i < args.size(); i += 2) { - data.Set(args[i], args[i + 1]); - } - *ret = data; -}); - -TVM_REGISTER_GLOBAL("runtime.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t { - return static_cast(n->size()); -}); - -TVM_REGISTER_GLOBAL("runtime.MapGetItem") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }); - -TVM_REGISTER_GLOBAL("runtime.MapCount") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); - -TVM_REGISTER_GLOBAL("runtime.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array { - Array rkvs; - for (const auto& kv : *n) { - rkvs.push_back(kv.first); - rkvs.push_back(kv.second); - } - return rkvs; -}); - -TVM_REGISTER_GLOBAL("runtime.GetShapeSize").set_body_typed([](ffi::Shape shape) { - return static_cast(shape.size()); -}); - -TVM_REGISTER_GLOBAL("runtime.GetShapeElem").set_body_typed([](ffi::Shape shape, int idx) { - ICHECK_LT(idx, shape.size()); - return shape[idx]; -}); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index 72225f39954f..da5ae3f62cb3 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -21,8 +21,8 @@ * \file src/runtime/contrib/amx/amx_config.cc * \brief extraction of AMX configuration on x86 platforms */ +#include #include -#include namespace tvm { namespace runtime { @@ -76,7 +76,7 @@ void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) { _tile_loadconfig(dst->a); } -TVM_REGISTER_GLOBAL("runtime.amx_tileconfig") +TVM_FFI_REGISTER_GLOBAL("runtime.amx_tileconfig") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int rows = args[0].cast(); int cols = args[1].cast(); @@ -90,7 +90,7 @@ TVM_REGISTER_GLOBAL("runtime.amx_tileconfig") }); // register a global packed function in c++,to init the system for AMX config -TVM_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // -----------Detect and request for AMX control---------------------- uint64_t bitmask = 0; int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); diff --git a/src/runtime/contrib/arm_compute_lib/acl_allocator.h b/src/runtime/contrib/arm_compute_lib/acl_allocator.h index d4e72a73314f..a755393209ec 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_allocator.h +++ b/src/runtime/contrib/arm_compute_lib/acl_allocator.h @@ -28,9 +28,9 @@ #include #include #include +#include #include #include -#include #include diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 5687e687cfb6..eeca2fcdf347 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for Arm Compute Library. */ +#include #include -#include #include "../json/json_node.h" #include "../json/json_runtime.h" @@ -593,8 +593,8 @@ runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_ return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") +TVM_FFI_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index cb921aa729a1..aed0080589e0 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -22,9 +22,9 @@ * \brief Simple JSON runtime for Apple BNNS primitives */ +#include #include #include -#include #include #include @@ -562,9 +562,9 @@ runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") .set_body_typed(BNNSJSONRuntime::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 155e1f05f197..4d04d8263447 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include extern "C" { #include @@ -123,7 +123,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -134,7 +134,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") CallGemm(args, ret, CblasDgemmOp()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index 18840eb55db1..68819d015326 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include extern "C" { #include @@ -46,7 +46,7 @@ struct DNNLSgemmOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 14b74d4736fc..a44cf1b365ec 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ #define TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ +#include #include -#include #include #include diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index f98df0c6d624..33b52e5e375d 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -20,9 +20,9 @@ /*! * \file Use external mkl library call. */ +#include #include #include -#include extern "C" { #include @@ -154,7 +154,7 @@ struct MKLDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -166,7 +166,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") }); // integer matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); @@ -177,7 +177,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -188,7 +188,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index c59068152793..5ee90e29b009 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -1830,8 +1830,8 @@ runtime::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.clml_runtime_create").set_body_typed(CLMLRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_clml") +TVM_FFI_REGISTER_GLOBAL("runtime.clml_runtime_create").set_body_typed(CLMLRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_clml") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h index faada2ddeeb5..4431b63cafcc 100644 --- a/src/runtime/contrib/clml/clml_runtime.h +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -32,9 +32,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 7b2733c4312e..f98c97f68b12 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -20,7 +20,7 @@ /*! * \file coreml_runtime.cc */ -#include +#include #include "coreml_runtime.h" @@ -192,7 +192,7 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.coreml_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = CoreMLRuntimeCreate(args[0], args[1]); }); @@ -249,7 +249,8 @@ Module CoreMLRuntimeLoadFromBinary(void* strm) { return Module(exec); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_coreml").set_body_typed(CoreMLRuntimeLoadFromBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_coreml") + .set_body_typed(CoreMLRuntimeLoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index e3222e3adc40..19d83e624d91 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -514,7 +514,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -539,7 +539,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") }); #if CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); @@ -557,7 +557,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") }); #endif // CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index c9c6cf85c6ba..8f7b6ac1f188 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for CUBLAS. */ +#include #include -#include #include #include @@ -153,9 +153,9 @@ runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 5844f802fd84..53e00fe14199 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -23,7 +23,7 @@ #include "cublas_utils.h" #include -#include +#include #include "../../cuda/cuda_common.h" diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 52c69c81cf08..a19fc192efd1 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -20,9 +20,9 @@ /*! * \file cuDNN kernel calls for backward algorithms. */ +#include #include #include -#include #include "cudnn_utils.h" @@ -185,7 +185,7 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c ret[0] = static_cast(best_algo); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -206,7 +206,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); @@ -225,7 +225,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") data_dtype, conv_dtype, verbose, ret); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -246,7 +246,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") dw, conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 87e6121e74c7..856d796e9038 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -20,9 +20,9 @@ /*! * \file cuDNN kernel calls for the forward algorithm. */ +#include #include #include -#include #include "cudnn_utils.h" @@ -153,7 +153,7 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid ret[0] = static_cast(best_algo); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -174,7 +174,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") dilation_v, x, w, y, bias, conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -219,7 +219,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc index f8b170fe2052..dffce6738907 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -24,8 +24,8 @@ #include "./attention.h" +#include #include -#include #include "../../../cuda/cuda_common.h" #include "../cudnn_utils.h" diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h index 4d0309fb3ba6..ae11764ce02c 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -26,7 +26,7 @@ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ #include -#include +#include #include #include diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 08909a3150c2..eda3b694d7f0 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for CUDNN. */ +#include #include -#include #include #include @@ -237,9 +237,9 @@ runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.cuDNNJSONRuntimeCreate").set_body_typed(cuDNNJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.cuDNNJSONRuntimeCreate").set_body_typed(cuDNNJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cudnn_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cudnn_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 3d7546d4e01b..8e2e85c67524 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -24,8 +24,8 @@ #include "cudnn_utils.h" #include +#include #include -#include #include #include @@ -265,7 +265,7 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { return CuDNNThreadEntry::ThreadLocal(false)->exists(); }); diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index c2b3ac3db84c..aa37acd2c3a9 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -21,8 +21,8 @@ * \file src/runtime/contrib/cudnn/softmax.cc * \brief Use external cudnn softmax function */ +#include #include -#include #include "cudnn_utils.h" @@ -77,12 +77,12 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r entry_ptr->softmax_entry.shape_desc, y->data)); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); }); diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc index e8aaf31fc5f3..e31c5fdfebf8 100644 --- a/src/runtime/contrib/curand/curand.cc +++ b/src/runtime/contrib/curand/curand.cc @@ -17,8 +17,8 @@ * under the License. */ #include -#include -#include +#include +#include #include "../../cuda/cuda_common.h" #include "./helper_cuda_kernels.h" @@ -112,7 +112,7 @@ void RandomFill(DLTensor* tensor) { TVMSynchronize(tensor->device.device_type, tensor->device.device_type, nullptr); } -TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); +TVM_FFI_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); } // namespace curand } // namespace runtime diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.h b/src/runtime/contrib/curand/helper_cuda_kernels.h index 582162579a3a..6df29ee69056 100644 --- a/src/runtime/contrib/curand/helper_cuda_kernels.h +++ b/src/runtime/contrib/curand/helper_cuda_kernels.h @@ -20,7 +20,7 @@ #define TVM_RUNTIME_CONTRIB_CURAND_HELPER_CUDA_KERNELS_H_ #include -#include +#include namespace tvm { namespace runtime { diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm.cu index f09925ceecd6..b1e152b1b064 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "group_gemm_runner.cuh" @@ -60,7 +60,7 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr static_cast(out->data), stream); } -TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") .set_body_typed(tvm_cutlass_group_gemm_sm90); } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu index bf257622f1f9..d9bd0a33ee25 100644 --- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "../cublas/cublas_utils.h" #include "blockwise_scaled_gemm_runner.cuh" @@ -153,9 +153,9 @@ void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a } } -TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") +TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm); -TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") +TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm); } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 485929570592..5146e62d8c5d 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "../cublas/cublas_utils.h" #include "gemm_runner.cuh" @@ -77,15 +77,15 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray } } -TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); -TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); -TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm.cu index fd528a22cc1a..104010f4c8ab 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include "group_gemm_runner.cuh" @@ -66,15 +66,15 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr static_cast(out->data), stream); } -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index 5fded82762a3..7cc053712b86 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -17,9 +17,9 @@ * under the License. */ +#include #include #include -#include #include "cutlass_kernels/cutlass_preprocessors.h" @@ -35,7 +35,7 @@ namespace runtime { // black box. // // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU. -TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") +TVM_FFI_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) { bool is_2d = packed_weight->ndim == 2; int num_experts = is_2d ? 1 : packed_weight->shape[0]; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 3e73b19116ee..9cc053ec7ca4 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -348,7 +348,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ } // DNNL Conv2d single OP -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto weights = args[1].cast(); diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index b06b17c17d8e..154ee12790f7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for DNNL. */ +#include #include -#include #include #include @@ -927,9 +927,9 @@ runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index 04e06d9c9e94..f12467a67e64 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -25,9 +25,9 @@ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ -#include +#include +#include #include -#include #include diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 2a2462786327..5d706836e6ce 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace tvm { namespace runtime { @@ -68,7 +68,7 @@ Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = EdgeTPURuntimeCreate(args[0], args[1]); }); diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index c85f15cc743a..07331e33defe 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -20,9 +20,9 @@ /*! * \file Use external hipblas library call. */ +#include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -407,7 +407,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -430,7 +430,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 2cd1223bc654..3f4be327c4b2 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for HIPBLAS. */ +#include #include -#include #include #include @@ -141,9 +141,10 @@ runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate").set_body_typed(HipblasJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate") + .set_body_typed(HipblasJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 02d91646518c..6facbb232b2c 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -23,7 +23,7 @@ #include "hipblas_utils.h" #include -#include +#include #include "../../rocm/rocm_common.h" diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 19eec4a0a026..247863c56a99 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -20,9 +20,9 @@ /*! * \file Use external miopen utils function */ +#include #include #include -#include #include @@ -34,7 +34,7 @@ namespace miopen { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); @@ -148,7 +148,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ret[0] = static_cast(best_algo); }); -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index b750e56c7e81..bb091fdf7aa1 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -23,7 +23,7 @@ #include "miopen_utils.h" #include -#include +#include #include #include diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index 021d0387defb..10289f22bdda 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -21,8 +21,8 @@ * \file src/runtime/contrib/miopen/softmax.cc * \brief Use external miopen softmax function */ +#include #include -#include #include "miopen_utils.h" @@ -79,12 +79,12 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); } -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); }); -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 4200477b2713..dbbb92dd05f7 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -24,7 +24,7 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto buf = args[0].cast(); auto img = args[1].cast(); @@ -57,7 +57,7 @@ imageIndex:0]; }); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto img = args[0].cast(); auto buf = args[1].cast(); @@ -76,7 +76,7 @@ buf -> dtype, nullptr); }); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { // MPS-NHWC auto data = args[0].cast(); diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index 77eb6dd03dd3..51285251c82e 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,7 +24,7 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index c2b7e3c7aa99..1dd1a2c1e3fc 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -26,10 +26,10 @@ #import #include +#include #include #include #include -#include #include diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index 01cfb385c7f5..3b21ba0e5dc5 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -476,9 +476,9 @@ bool MarvellHardwareModuleNode::use_dpdk_cb = false; ml_tvmc_cb MarvellHardwareModuleNode::tvmc_cb_ = {}; ml_dpdk_cb MarvellHardwareModuleNode::dpdk_cb_ = {}; -TVM_REGISTER_GLOBAL("runtime.mrvl_hw_runtime_create") +TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_hw_runtime_create") .set_body_typed(MarvellHardwareModuleRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_hw") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_hw") .set_body_typed(MarvellHardwareModuleNode::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 186cc3b3a859..701ae6ed8dcd 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -24,9 +24,9 @@ #include #include +#include #include #include -#include #include #include @@ -149,9 +149,9 @@ runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.mrvl_runtime_create") +TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_runtime_create") .set_body_typed(MarvellSimulatorModuleRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim") .set_body_typed(MarvellSimulatorModuleNode::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index b062a50dccb5..c63bafcd0089 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -25,8 +25,8 @@ #include "mrvl_sw_runtime_lib.h" #include +#include #include -#include #include #include diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 7ddbcb34ad02..8819cfd2fc4a 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include @@ -348,9 +348,10 @@ runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create").set_body_typed(MSCTensorRTRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create") + .set_body_typed(MSCTensorRTRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/mscclpp/allreduce.cu b/src/runtime/contrib/mscclpp/allreduce.cu index 7ead504340be..a5bebbc56167 100644 --- a/src/runtime/contrib/mscclpp/allreduce.cu +++ b/src/runtime/contrib/mscclpp/allreduce.cu @@ -19,7 +19,7 @@ #include #include -#include +#include #include "msccl.cuh" diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index c63098873da1..0fcf9fded0a8 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include @@ -240,9 +240,9 @@ runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& grap return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index cac0c1bc050a..090457829e69 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -19,9 +19,9 @@ #include #include #include +#include #include #include -#include #include "../../cuda/cuda_common.h" @@ -107,11 +107,11 @@ void InitNVSHMEMWrapper(String args) { InitNVSHMEM(uid_64, num_workers, worker_id_start); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") .set_body_typed(InitNVSHMEMWrapper); } // namespace runtime diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu index cf3a9958f895..2dad73707df7 100644 --- a/src/runtime/contrib/nvshmem/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include template __device__ int64_t calc_flattened_index(int shape[dim], int index[dim]) { @@ -329,5 +329,5 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages, return 0; } -TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); -TVM_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage); +TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); +TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage); diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index 5f7b181e27e3..86427eaa60dd 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -18,9 +18,9 @@ */ #include #include +#include #include #include -#include #include @@ -92,14 +92,14 @@ NDArray NVSHMEMEmpty(ffi::Shape shape, DataType dtype, Device device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); void NVSHMEMFinalize() { NVSHMEMAllocator::Global()->Clear(); nvshmem_finalize(); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index 9f98890b93ac..882cee36b246 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -290,7 +290,7 @@ MetricCollector CreatePAPIMetricCollector(Map> metr TVM_REGISTER_OBJECT_TYPE(PAPIEventSetNode); TVM_REGISTER_OBJECT_TYPE(PAPIMetricCollectorNode); -TVM_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") .set_body_typed([](Map> metrics) { return PAPIMetricCollector(metrics); }); diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index ed4e1a3fad38..8f05a7241b02 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -21,9 +21,9 @@ * \file External random functions for tensor. */ #include +#include #include #include -#include #include #include @@ -69,7 +69,7 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } -TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.randint") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); int64_t low = args[0].cast(); @@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") }) }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.uniform") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double low = args[0].cast(); @@ -112,7 +112,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") entry->random_engine.SampleUniform(out, low, high); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.normal") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double loc = args[0].cast(); @@ -121,14 +121,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") entry->random_engine.SampleNormal(out, loc, scale); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); auto out = args[0].cast(); entry->random_engine.RandomFill(out); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) -> void { const auto curand = tvm::ffi::Function::GetGlobal("runtime.contrib.curand.RandomFill"); auto out = args[0].cast(); diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 88c6071e1efd..2969d7fd0e5e 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -23,9 +23,9 @@ #include "rocblas.h" #include +#include #include #include -#include namespace tvm { namespace contrib { @@ -65,7 +65,7 @@ struct RocBlasThreadEntry { typedef dmlc::ThreadLocalStore RocBlasThreadStore; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); @@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") ldc)); }); -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index f413af696661..62639e684055 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -22,7 +22,9 @@ */ #include -#include +#include +#include +#include #include #include @@ -77,7 +79,7 @@ struct float16 { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto sort_num = args[1].cast(); @@ -216,7 +218,7 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto output = args[1].cast(); @@ -229,8 +231,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") "input ndim " << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { @@ -312,7 +314,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.sort") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto output = args[1].cast(); @@ -442,7 +444,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.topk") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); DLTensor* values_out = nullptr; @@ -467,8 +469,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataTypeToString(indices_out->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = + (indices_out == nullptr) ? "int64" : ffi::DLDataTypeToString(indices_out->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index e1f205e22f10..a8bd43127258 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include @@ -524,9 +524,9 @@ runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& g return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 990475069574..74cfcad3e650 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include namespace tvm { namespace runtime { @@ -183,11 +183,11 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.tflite_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index aa1befaeef32..19f82b1855b4 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include @@ -232,7 +232,7 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_GE(args.num_args, 4); auto input = args[0].cast(); @@ -279,7 +279,7 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr); } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_GE(args.num_args, 5); auto keys_in = args[0].cast(); @@ -394,7 +394,7 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4); auto data = args[0].cast(); diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 2b59044f844c..15e57bd297d4 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -735,7 +735,7 @@ void single_query_cached_kv_attention_v2( } } -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") .set_body_typed([](const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache, const DLTensor* block_tables, const DLTensor* context_lens, int block_size, @@ -759,10 +759,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") }); // Expose for testing -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1") .set_body_typed(single_query_cached_kv_attention_v1); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2") .set_body_typed(single_query_cached_kv_attention_v2); } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_alloc.cc b/src/runtime/contrib/vllm/cache_alloc.cc index aea50aa47a5c..dd2b7bd5bb37 100644 --- a/src/runtime/contrib/vllm/cache_alloc.cc +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -17,8 +17,8 @@ * under the License. */ #include +#include #include -#include namespace tvm { namespace runtime { @@ -48,7 +48,7 @@ Array AllocateKVCache(int head_size, int num_layers, int num_heads, int return cache; } -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); } // namespace vllm } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index b53cd094c1aa..d762010427d4 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include #include @@ -130,7 +130,7 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t* value_cache namespace tvm { namespace runtime { -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { int num_tokens = key->shape[0]; @@ -155,7 +155,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") return Array{key_cache, value_cache}; }); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") .set_body_typed([](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { int num_tokens = slot_mapping->shape[0]; int num_heads = value_cache->shape[1]; @@ -184,7 +184,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") return Array{key, value}; }); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks") .set_body_typed([](Array key_value_caches, NDArray block_mapping) { auto num_layers = key_value_caches.size() / 2; auto num_pairs = block_mapping->shape[0] / 2; diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index df2271e64732..68594f0769fe 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -21,9 +21,9 @@ * \file cpu_device_api.cc */ #include +#include #include #include -#include #include #include @@ -150,7 +150,7 @@ void CPUDeviceAPI::FreeWorkspace(Device dev, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -TVM_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CPUDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 1dc928e77801..399312e19321 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include @@ -286,15 +286,16 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.cuda_host") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global(); + *rv = static_cast(ptr); + }); class CUDATimerNode : public TimerNode { public: @@ -329,7 +330,7 @@ class CUDATimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(CUDATimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) { return Timer(make_object()); }); @@ -342,9 +343,9 @@ TVM_DLL String GetCudaFreeMemory() { return ss.str(); } -TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); +TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); -TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { return static_cast(CUDAThreadEntry::ThreadLocal()->stream); }); @@ -354,7 +355,7 @@ TVM_DLL int GetCudaDeviceCount() { return count; } -TVM_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); +TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index db01b76cb531..acb2dc6cdf11 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include @@ -290,10 +290,10 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index ae7c057be0cc..726df80de8bc 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -19,8 +19,8 @@ #include "../../../3rdparty/nvbench/l2_cache_flush.h" #include +#include #include -#include #include "cuda_common.h" @@ -32,11 +32,12 @@ typedef dmlc::ThreadLocalStore L2FlushStore; L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } -TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; - cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; - L2Flush::ThreadLocal()->Flush(stream); -}); +TVM_FFI_REGISTER_GLOBAL("l2_cache_flush_cuda") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; + cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; + L2Flush::ThreadLocal()->Flush(stream); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc index eee8ed3cb6f9..483b8cdb592b 100644 --- a/src/runtime/debug_compile.cc +++ b/src/runtime/debug_compile.cc @@ -33,12 +33,12 @@ // #include // #include -// #include +// #include // #include // #include // #include -// #include +// #include // #include namespace tvm { @@ -46,7 +46,7 @@ namespace debug { using namespace tvm::runtime; -// TVM_REGISTER_GLOBAL("tvm.debug.Test").set_body_typed([](PrimExpr value) { +// TVM_FFI_REGISTER_GLOBAL("tvm.debug.Test").set_body_typed([](PrimExpr value) { // LOG(INFO) << value; // return value; // }); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc new file mode 100644 index 000000000000..a80d6ebdbda6 --- /dev/null +++ b/src/runtime/device_api.cc @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file c_runtime_api.cc + * \brief Device specific implementations + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +class DeviceAPIManager { + public: + static const int kMaxDeviceAPI = TVMDeviceExtType_End; + // Get API + static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } + static DeviceAPI* Get(int dev_type, bool allow_missing = false) { + return Global()->GetAPI(dev_type, allow_missing); + } + + private: + std::array api_; + DeviceAPI* rpc_api_{nullptr}; + std::mutex mutex_; + // constructor + DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } + // Global static variable. + static DeviceAPIManager* Global() { + static DeviceAPIManager* inst = new DeviceAPIManager(); + return inst; + } + // Get or initialize API. + DeviceAPI* GetAPI(int type, bool allow_missing) { + if (type < kRPCSessMask) { + if (api_[type] != nullptr) return api_[type]; + std::lock_guard lock(mutex_); + if (api_[type] != nullptr) return api_[type]; + api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); + return api_[type]; + } else { + if (rpc_api_ != nullptr) return rpc_api_; + std::lock_guard lock(mutex_); + if (rpc_api_ != nullptr) return rpc_api_; + rpc_api_ = GetAPI("rpc", allow_missing); + return rpc_api_; + } + } + DeviceAPI* GetAPI(const std::string name, bool allow_missing) { + std::string factory = "device_api." + name; + const auto f = tvm::ffi::Function::GetGlobal(factory); + if (!f.has_value()) { + ICHECK(allow_missing) << "Device API " << name << " is not enabled."; + return nullptr; + } + void* ptr = (*f)().cast(); + return static_cast(ptr); + } +}; + +DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { + return DeviceAPIManager::Get(static_cast(dev.device_type), allow_missing); +} + +void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); +} + +static size_t GetDataAlignment(const DLDataType dtype) { + size_t align = (dtype.bits / 8) * dtype.lanes; + if (align < kAllocAlignment) return kAllocAlignment; + return align; +} + +size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { + if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + size_t size = 1; + for (int i = 0; i < arr.ndim; ++i) { + size *= static_cast(arr.shape[i]); + } + size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; + return size; + } + LOG(FATAL) << "Device does not support physical mem computation with " + << "specified memory scope: " << mem_scope.value(); + return 0; +} + +void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope) { + if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") { + // by default, we can always redirect to the flat memory allocations + DLTensor temp; + temp.data = nullptr; + temp.device = dev; + temp.ndim = ndim; + temp.dtype = dtype; + temp.shape = const_cast(shape); + temp.strides = nullptr; + temp.byte_offset = 0; + size_t size = GetDataSize(temp); + size_t alignment = GetDataAlignment(temp.dtype); + return AllocDataSpace(dev, size, alignment, dtype); + } + LOG(FATAL) << "Device does not support allocate data space with " + << "specified memory scope: " << mem_scope.value(); + return nullptr; +} + +void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + // by default, we can always redirect to the flat memory copy operation. + size_t nbytes = GetDataSize(*from); + ICHECK_EQ(nbytes, GetDataSize(*to)); + + ICHECK(IsContiguous(*from) && IsContiguous(*to)) + << "CopyDataFromTo only support contiguous array for now"; + CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, + to->device, from->dtype, stream); +} + +void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, Device dev_from, Device dev_to, + DLDataType type_hint, TVMStreamHandle stream) { + LOG(FATAL) << "Device does not support CopyDataFromTo."; +} + +void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } + +TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } + +void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} + +TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } + +void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { +} + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamCreate").set_body_typed([](DLDevice dev) { + return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); +}); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamFree") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_SetStream") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSync") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSyncFromTo") + .set_body_typed([](DLDevice dev, int64_t src, int64_t dst) { + DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), + reinterpret_cast(dst)); + }); + +// set device api +TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + DeviceAPIManager::Get(dev)->SetDevice(dev); + }); + +// set device api +TVM_FFI_REGISTER_GLOBAL("runtime.GetDeviceAttr") + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + + DeviceAttrKind kind = static_cast(args[2].cast()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); + if (api != nullptr) { + api->GetAttr(dev, kind, ret); + } else { + *ret = 0; + } + } else { + DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); + } + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.TVMSetStream") + .set_body_typed([](int device_type, int device_id, void* stream) { + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + DeviceAPIManager::Get(dev)->SetStream(dev, stream); + }); +} // namespace runtime +} // namespace tvm + +using namespace tvm::runtime; + +int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { + TVM_FFI_SAFE_CALL_BEGIN(); + *func = const_cast( + static_cast(mod_node)->GetFuncFromEnv(func_name)->get()); + TVM_FFI_SAFE_CALL_END(); +} + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, + int dtype_bits_hint) { + DLDevice dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast(size), type_hint); +} + +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + DLDevice dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); + return 0; +} + +int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { + if (*handle == nullptr) { + *handle = reinterpret_cast(1); + return (*f)(cdata); + } + return 0; +} diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 6b17c7ba3aa8..034a1cf56524 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -18,9 +18,9 @@ */ #include "./bcast_session.h" +#include #include #include -#include #include diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 4cd3e98d8862..b3d04d1d5b6b 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -18,11 +18,11 @@ */ #include #include +#include #include #include #include #include -#include #include #include @@ -121,9 +121,9 @@ void SyncWorker() { } } -TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); -TVM_REGISTER_GLOBAL("runtime.disco.empty") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty") .set_body_typed([](ffi::Shape shape, DataType dtype, Device device, bool worker0_only, bool in_group) -> Optional { int worker_id = WorkerId(); @@ -137,37 +137,39 @@ TVM_REGISTER_GLOBAL("runtime.disco.empty") } }); -TVM_REGISTER_GLOBAL("runtime.disco.allreduce") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.allreduce") .set_body_typed([](NDArray send, ffi::Shape reduce_kind, bool in_group, NDArray recv) { int kind = IntegerFromShape(reduce_kind); CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; AllReduce(send, static_cast(kind), in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); -TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); -TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); -TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ffi::Shape { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0") + .set_body_typed(BroadcastFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ffi::Shape { return ffi::Shape({WorkerId()}); }); -TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { return WorkerId(); }); -TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { return DiscoWorker::ThreadLocal()->default_device; }); -TVM_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core").set_body_typed([](ffi::Shape cpu_ids) { - int worker_id = WorkerId(); - ICHECK_LT(worker_id, static_cast(cpu_ids.size())); - const auto f_set_thread_affinity = - tvm::ffi::Function::GetGlobalRequired("tvm.runtime.threading.set_current_thread_affinity"); - f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core") + .set_body_typed([](ffi::Shape cpu_ids) { + int worker_id = WorkerId(); + ICHECK_LT(worker_id, static_cast(cpu_ids.size())); + const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired( + "tvm.runtime.threading.set_current_thread_affinity"); + f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index 2ae5be1d453b..778ecc16e5a2 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -18,9 +18,9 @@ */ #include +#include #include #include -#include #include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h" #include "../../cuda/cuda_common.h" @@ -212,11 +212,10 @@ memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType dtype_hint) return storage; } -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage); -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear").set_body_typed([]() { - CUDAIPCMemoryAllocator::Global()->Clear(); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear") + .set_body_typed([]() { CUDAIPCMemoryAllocator::Global()->Clear(); }); /******************** CUDAIPCMemoryObj ********************/ diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index d969005f9476..fa7ef040f3ed 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -18,9 +18,9 @@ */ #include +#include #include #include -#include #include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h" #include "../nccl/nccl_context.h" @@ -112,7 +112,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { ctx->GetDefaultStream()); } -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce); } // namespace cuda_ipc } // namespace nccl diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index 7f98feacd83b..b01d378c447f 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -16,11 +16,11 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include #include -#include #include "../../support/process_id.h" #include "./protocol.h" diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 9c25d4abb68e..6cd012b64e11 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include @@ -294,7 +294,7 @@ void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, proxy.MainLoop(); } -TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") .set_body_typed(RemoteSocketSessionEntryPoint); Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, @@ -303,9 +303,9 @@ Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, c return Session(n); } -TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); -TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) { LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, " << num_workers_per_node << " workers per node, and " << num_groups << " groups."; diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index fee20f7cdeda..59624ac0bfb1 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -21,10 +21,10 @@ #define __STDC_FORMAT_MACROS #endif #include +#include #include #include #include -#include #include #include @@ -406,15 +406,15 @@ Array ShardLoaderObj::LoadAllPresharded() const { return params; } -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad") .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->Load(IntegerFromShape(weight_index)); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " @@ -422,14 +422,15 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") return loader->LoadPresharded(IntegerFromShape(weight_index)); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](ObjectRef loader_obj) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadAll(); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll") + .set_body_typed([](ObjectRef loader_obj) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " + << loader_obj->GetTypeKey(); + return loader->LoadAll(); + }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") .set_body_typed([](ObjectRef loader_obj) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " @@ -437,7 +438,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") return loader->LoadAllPresharded(); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0") .set_body_typed([](ObjectRef loader_obj, int param_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 9a0b3fb442d5..8095cbeeea4a 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -325,41 +325,42 @@ void SyncWorker() { StreamSynchronize(stream); } -TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String { return TVM_DISCO_CCL_NAME; }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") .set_body_typed(InitCCLPerWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; nccl::AllReduce(send, static_cast(kind), in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") .set_body_typed([](NDArray send, bool in_group, NDArray recv) { nccl::AllGather(send, in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") .set_body_typed(BroadcastFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") .set_body_typed(ScatterFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") .set_body_typed(GatherToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") .set_body_typed(RecvFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") .set_body_typed(SendToNextGroup); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") .set_body_typed(RecvFromPrevGroup); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") .set_body_typed(SendToWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") .set_body_typed(RecvFromWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker") + .set_body_typed(SyncWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME - ".test_send_to_next_group_recv_from_prev_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME + ".test_send_to_next_group_recv_from_prev_group") .set_body_typed([](NDArray buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; @@ -373,7 +374,7 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME } }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") .set_body_typed([](NDArray buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index d70efebdc844..fff165bfdd04 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -21,10 +21,10 @@ #define TVM_RUNTIME_DISCO_NCCL_NCCL_CONTEXT_H_ #include -#include +#include +#include #include #include -#include #include "../../../support/process_id.h" #include "../utils.h" diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 6f2042aa2529..eff03ea2536b 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -16,11 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include +#include #include #include #include -#include #include #include @@ -197,8 +197,8 @@ void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_f worker.MainLoop(); } -TVM_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); -TVM_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index f0e6cd28a3f5..9536c2911d1b 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -21,10 +21,10 @@ #include #include -#include +#include +#include #include #include -#include #include #include diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index 467888c65768..f2edfd59a27e 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -16,10 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include -#include namespace tvm { namespace runtime { @@ -32,27 +32,27 @@ struct SessionObj::FFI { TVM_REGISTER_OBJECT_TYPE(DRefObj); TVM_REGISTER_OBJECT_TYPE(SessionObj); -TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); -TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") .set_body_method(&DRefObj::DebugGetFromRemote); -TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom").set_body_method(&DRefObj::DebugCopyFrom); -TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom").set_body_method(&DRefObj::DebugCopyFrom); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") .set_body_method(&SessionObj::GetNumWorkers); -TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") .set_body_method(&SessionObj::GetGlobalFunc); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") .set_body_method(&SessionObj::CopyFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") .set_body_method(&SessionObj::CopyToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionObj::SyncWorker); -TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionObj::SyncWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // .set_body_method(&SessionObj::InitCCL); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Session self = args[0].cast(); *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); }); -TVM_REGISTER_GLOBAL("runtime.disco.SessionShutdown").set_body_method(&SessionObj::Shutdown); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionShutdown").set_body_method(&SessionObj::Shutdown); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 1f1318410226..c03a1bd9f4fd 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -17,7 +17,7 @@ * under the License. */ #include -#include +#include #include #include diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 7ae4971a85f1..d64d893ce12b 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,10 +21,10 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ +#include #include #include #include -#include #include "library_module.h" @@ -149,7 +149,7 @@ ObjectPtr CreateDSOLibraryObject(std::string library_path) { return n; } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_so") .set_body_typed([](std::string library_path, std::string) { ObjectPtr n = CreateDSOLibraryObject(library_path); return CreateModuleFromLibrary(n); diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index d8d10f885234..2aa377b9f8bd 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -24,8 +24,8 @@ #include #include +#include #include -#include #include #include @@ -237,22 +237,23 @@ std::string SaveParams(const Map& params) { return bytes; } -TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map& params) { - std::string s = ::tvm::runtime::SaveParams(params); - return ffi::Bytes(std::move(s)); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.SaveParams") + .set_body_typed([](const Map& params) { + std::string s = ::tvm::runtime::SaveParams(params); + return ffi::Bytes(std::move(s)); + }); -TVM_REGISTER_GLOBAL("runtime.SaveParamsToFile") +TVM_FFI_REGISTER_GLOBAL("runtime.SaveParamsToFile") .set_body_typed([](const Map& params, const String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); SaveParams(&strm, params); }); -TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const ffi::Bytes& s) { +TVM_FFI_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); }); -TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { +TVM_FFI_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); return LoadParams(&strm); }); diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h index 8cb8a3209514..b426825fc21f 100644 --- a/src/runtime/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ -#include +#include #include #include #include diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index c959e39e1d39..4c95d68b2dc3 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -22,9 +22,9 @@ */ #include "hexagon_common.h" +#include #include #include -#include #include #include @@ -56,7 +56,7 @@ class HexagonTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(HexagonTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.hexagon").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.hexagon").set_body_typed([](Device dev) { return Timer(make_object()); }); } // namespace hexagon @@ -89,7 +89,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } } // namespace detail -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); *rv = CreateModuleFromLibrary(n); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 40294324018b..0bc7e2b80194 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -24,9 +24,9 @@ #include "hexagon_device_api.h" #include +#include #include #include -#include #include #include @@ -190,7 +190,7 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto dst = args[0].cast(); auto src = args[1].cast(); @@ -209,7 +209,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); void* dst = args[1].cast(); @@ -226,7 +226,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") *rv = static_cast(ret); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_wait") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); int inflight = args[1].cast(); @@ -235,21 +235,21 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -274,7 +274,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.free_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -291,28 +291,29 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->AcquireResources(); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.release_resources") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.release_resources") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->ReleaseResources(); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); *rv = static_cast(api->VtcmPool()->VtcmDeviceBytes()); }); -TVM_REGISTER_GLOBAL("device_api.hexagon").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global(); + *rv = static_cast(ptr); + }); } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 6ed2a2757f68..a5a8de45357a 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -24,8 +24,8 @@ #include "hexagon_module.h" #include +#include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/runtime/hexagon/hexagon_thread_manager.h index 9bf6bb6efe64..31f3d0466972 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.h +++ b/src/runtime/hexagon/hexagon_thread_manager.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ -#include +#include #include #include diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index 88b8f1470cf3..18c89722f4b0 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ -#include +#include #include #include #include diff --git a/src/runtime/hexagon/ops/conv2d.h b/src/runtime/hexagon/ops/conv2d.h index 79bd0217179b..5865d46117a0 100644 --- a/src/runtime/hexagon/ops/conv2d.h +++ b/src/runtime/hexagon/ops/conv2d.h @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include #include diff --git a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc index 5c764355aa58..5f171894d9cd 100644 --- a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -44,8 +44,7 @@ // 4: int stride_h // 5: int stride_w // 6: DLTensor output (NHWC) -extern "C" int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle); +extern "C" int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val); namespace tvm { namespace runtime { @@ -403,26 +402,27 @@ void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) } // namespace runtime } // namespace tvm -int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle) { +int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; ICHECK_EQ(num_args, 7) << "Unexpected number of arguments"; - ICHECK_EQ(type_codes[0], kTVMDLTensorHandle) + ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(type_codes[1], kTVMDLTensorHandle) + ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(type_codes[2], kDLInt) + ICHECK_EQ(args[2].type_index, kTVMFFIInt) << "Third argument is expected to be the pad_top offset"; // pad_top offset - ICHECK_EQ(type_codes[3], kDLInt) + ICHECK_EQ(args[3].type_index, kTVMFFIInt) << "Fourth argument is expected to be the pad_left offset"; // pad_left offset - ICHECK_EQ(type_codes[4], kDLInt) << "Fifth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(type_codes[5], kDLInt) << "Sixth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(type_codes[6], kTVMDLTensorHandle) + ICHECK_EQ(args[4].type_index, kTVMFFIInt) + << "Fifth argument is expected to be the stride_h"; // stride_h + ICHECK_EQ(args[5].type_index, kTVMFFIInt) + << "Sixth argument is expected to be the stride_w"; // stride_w + ICHECK_EQ(args[6].type_index, kTVMFFIDLTensorPtr) << "Seventh argument is expected to be the output tensor"; // output - auto* act_flat = static_cast(args[0].v_handle); - auto* wgt_flat = static_cast(args[1].v_handle); - auto* out_flat = static_cast(args[6].v_handle); + auto* act_flat = static_cast(args[0].v_ptr); + auto* wgt_flat = static_cast(args[1].v_ptr); + auto* out_flat = static_cast(args[6].v_ptr); // Temporary assertion until multiple batches are supported ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; diff --git a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc index 99f7c245f557..30cba60cf1a8 100644 --- a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc @@ -20,13 +20,12 @@ #include #include #include -#include +#include #include #include "conv2d.h" -extern "C" int conv2d_packed_quant(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle); +extern "C" int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val); namespace tvm { namespace runtime { @@ -230,30 +229,38 @@ void conv_layer_int8_hvx_whole(DLTensor& cr_out, const DLTensor& cr_act, // NOL } // namespace runtime } // namespace tvm -int conv2d_packed_quant(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle) { +int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; ICHECK_EQ(num_args, 13) << "Unexpected number of arguments"; - ICHECK_EQ(type_codes[0], kTVMDLTensorHandle) + ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(type_codes[1], kTVMDLTensorHandle) + ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(type_codes[2], kDLFloat) << "Third argument is expected to be the activation scale"; - ICHECK_EQ(type_codes[3], kDLInt) << "Fourth argument is expected to be the activation zero point"; - ICHECK_EQ(type_codes[4], kDLFloat) << "Fifth argument is expected to be the weight scale"; - ICHECK_EQ(type_codes[5], kDLInt) << "Sixth argument is expected to be the weight zero point"; - ICHECK_EQ(type_codes[6], kDLFloat) << "Seventh argument is expected to be the output scale"; - ICHECK_EQ(type_codes[7], kDLInt) << "Eigth argument is expected to be the output zero point"; - ICHECK_EQ(type_codes[8], kDLInt) << "Nineth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(type_codes[9], kDLInt) << "Tenth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(type_codes[10], kDLInt) << "Eleventh argument is expected to be fixed final scale"; - ICHECK_EQ(type_codes[11], kDLInt) << "Twelfth argument is expected to be scale factor"; - ICHECK_EQ(type_codes[12], kTVMDLTensorHandle) + ICHECK_EQ(args[2].type_index, kTVMFFIFloat) + << "Third argument is expected to be the activation scale"; + ICHECK_EQ(args[3].type_index, kTVMFFIInt) + << "Fourth argument is expected to be the activation zero point"; + ICHECK_EQ(args[4].type_index, kTVMFFIFloat) + << "Fifth argument is expected to be the weight scale"; + ICHECK_EQ(args[5].type_index, kTVMFFIInt) + << "Sixth argument is expected to be the weight zero point"; + ICHECK_EQ(args[6].type_index, kTVMFFIFloat) + << "Seventh argument is expected to be the output scale"; + ICHECK_EQ(args[7].type_index, kTVMFFIInt) + << "Eigth argument is expected to be the output zero point"; + ICHECK_EQ(args[8].type_index, kTVMFFIInt) + << "Nineth argument is expected to be the stride_h"; // stride_h + ICHECK_EQ(args[9].type_index, kTVMFFIInt) + << "Tenth argument is expected to be the stride_w"; // stride_w + ICHECK_EQ(args[10].type_index, kTVMFFIInt) + << "Eleventh argument is expected to be fixed final scale"; + ICHECK_EQ(args[11].type_index, kTVMFFIInt) << "Twelfth argument is expected to be scale factor"; + ICHECK_EQ(args[12].type_index, kTVMFFIDLTensorPtr) << "Thirteenth argument is expected to be the output tensor"; // output - auto* act_flat = static_cast(args[0].v_handle); - auto* wgt_flat = static_cast(args[1].v_handle); - auto* out_flat = static_cast(args[12].v_handle); + auto* act_flat = static_cast(args[0].v_ptr); + auto* wgt_flat = static_cast(args[1].v_ptr); + auto* out_flat = static_cast(args[12].v_ptr); // Temporary assertion until multiple batches are supported ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 265e5bb12e57..0f71f7265024 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -21,7 +21,7 @@ * \file hexagon_session.cc */ -#include +#include extern "C" { #include @@ -109,7 +109,7 @@ class HexagonTransportChannel : public RPCChannel { remote_handle64 _handle = AEE_EUNKNOWN; }; -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 78d65fb8deeb..d9a9d007b090 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -27,9 +27,9 @@ extern "C" { #include #include +#include #include #include -#include #include #include @@ -329,14 +329,14 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); -TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); @@ -354,7 +354,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index a98abe634e8b..2301ffc13d17 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -32,8 +32,8 @@ #include "../../hexagon_common.h" #include "../../profiler/prof_utils.h" #include "hexagon_sim_proto.h" +#include "tvm/ffi/function.h" #include "tvm/runtime/packed_func.h" -#include "tvm/runtime/registry.h" namespace tvm { namespace runtime { @@ -332,14 +332,14 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); -TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); @@ -357,7 +357,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 7366371b491a..3211b8d0472f 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include // POSIX includes #include #include @@ -1370,7 +1370,7 @@ std::optional SimulatorRPCChannel::to_nullptr(const detail::Mayb .Default(std::nullopt); } -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index f580a6d667f1..18f973daf159 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -25,8 +25,8 @@ #include #include +#include #include -#include #include #include diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index ccc0b3193b87..60ce95e2369b 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_ +#include #include -#include #include #include diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index 2d4164ce4425..45e83f33e2da 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -121,7 +121,7 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int li if (filename) { // Stack frames for TVM FFI if (strstr(filename, "include/tvm/runtime/packed_func.h") || - strstr(filename, "include/tvm/runtime/registry.h") || + strstr(filename, "include/tvm/ffi/function.h") || strstr(filename, "src/runtime/c_runtime_api.cc")) { return true; } diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index 2b5c217c72be..b6c2a098d474 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -21,8 +21,8 @@ * \file tvm/runtime/memory/memory_manager.cc * \brief Allocate and manage memory for the runtime. */ +#include #include -#include #include #include @@ -264,7 +264,7 @@ void Allocator::Clear() { // Pooled allocator will override this method. } -TVM_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); } // namespace memory } // namespace runtime diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index c415468088ed..b93db5a19c6d 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -35,8 +35,6 @@ #include #include -#include "runtime_base.h" - namespace tvm { namespace runtime { diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index ab383732ea8c..e57907e06ecc 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -30,7 +30,7 @@ #import #import #import -#include +#include #include #include #include diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 83f2c38a2bd5..46824b1600ee 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -21,8 +21,8 @@ * \file metal_device_api.mm */ #include +#include #include -#include #include "metal_common.h" namespace tvm { @@ -362,12 +362,12 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = MetalWorkspace::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); @@ -403,7 +403,7 @@ virtual void Stop() { TVM_REGISTER_OBJECT_TYPE(MetalTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) { return Timer(make_object(dev)); }); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 36062cae39c5..f7c59156cb6a 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -22,8 +22,8 @@ */ #include "metal_module.h" #include +#include #include -#include #include #include #include @@ -287,7 +287,7 @@ Module MetalModuleCreate(std::unordered_map smap, return Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.create_metal_module") +TVM_FFI_REGISTER_GLOBAL("runtime.module.create_metal_module") .set_body_typed([](Map smap, std::string fmap_json, std::string fmt, std::string source) { std::istringstream stream(fmap_json); @@ -317,6 +317,6 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(smap, fmap, fmt, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 727e2d650518..ccfd3d079280 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -30,7 +30,7 @@ #include #include -#include +#include #include #include diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 5ba2248f7627..d16239079c67 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -21,9 +21,9 @@ * \file module.cc * \brief TVM module system */ +#include #include #include -#include #include #include @@ -166,52 +166,52 @@ bool RuntimeEnabled(const String& target_str) { return tvm::ffi::Function::GetGlobal(f_name).has_value(); } -TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); +TVM_FFI_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); -TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { return mod->GetSource(fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { return static_cast(mod->imports().size()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { return mod->imports().at(index); }); -TVM_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { mod->ClearImports(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { return mod->GetFormat(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); -TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleSaveToFile") .set_body_typed([](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { return mod->GetPropertyMask(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") .set_body_typed([](Module mod, String name, bool query_imports) { return mod->ImplementsFunction(std::move(name), query_imports); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFunction") .set_body_typed([](Module mod, String name, bool query_imports) { return mod->GetFunction(name, query_imports); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImport").set_body_typed([](Module mod, Module other) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImport").set_body_typed([](Module mod, Module other) { mod->Import(other); }); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index d30c29689963..2bf56e876164 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -21,13 +21,12 @@ * \file ndarray.cc * \brief NDArray container infratructure. */ -#include +#include +#include #include #include #include -#include -#include "runtime_base.h" #include "tvm/runtime/data_type.h" namespace tvm { @@ -73,10 +72,11 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); } -void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { +void NDArray::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; - ICHECK(IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; DLTensor to; to.data = const_cast(data); @@ -87,9 +87,9 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { to.strides = nullptr; to.byte_offset = 0; - DeviceAPI::Get(handle->device)->CopyDataFromTo(const_cast(handle), &to, nullptr); + DeviceAPI::Get(handle->device)->CopyDataFromTo(const_cast(handle), &to, stream); // Synchronize in case data become unavailable later. - DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); + DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional mem_scope) { @@ -106,17 +106,6 @@ NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional< return ffi::NDArray::FromNDAlloc(DeviceAPIAlloc(), shape, dtype, dev, mem_scope); } -struct NDArray::Internal { - // Implementation of API function - static DLTensor* MoveToFFIHandle(NDArray arr) { - DLTensor* handle = NDArray::FFIGetHandle(arr); - // move and discard as handle is already obtained in FFIGetHandle - ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(arr)); - return handle; - } - static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); } -}; - NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_byte_offset) const { ICHECK(data_ != nullptr); @@ -178,7 +167,7 @@ NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, void NDArray::CopyToBytes(void* data, size_t nbytes) const { ICHECK(data != nullptr); ICHECK(data_ != nullptr); - ArrayCopyToBytes(get_mutable(), data, nbytes); + NDArray::CopyToBytes(get_mutable(), data, nbytes); } void NDArray::CopyFromBytes(const void* data, size_t nbytes) { @@ -222,79 +211,19 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str using namespace tvm::runtime; -int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { - API_BEGIN(); - *out_tindex = - tvm::ffi::details::ObjectUnsafe::GetHeader(TVMArrayHandleToObjectHandle(handle))->type_index; - API_END(); -} - -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { - API_BEGIN(); - DLDataType dtype; - dtype.code = static_cast(dtype_code); - dtype.bits = static_cast(dtype_bits); - dtype.lanes = static_cast(dtype_lanes); - tvm::Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - auto ndarray = NDArray::Empty(tvm::ffi::Shape(shape, shape + ndim), dtype, dev); - - *out = NDArray::Internal::MoveToFFIHandle(ndarray); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty); -TVM_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView); -int TVMArrayFree(TVMArrayHandle handle) { - API_BEGIN(); - NDArray::Internal::FFIDecRef(handle); - API_END(); -} - -int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) { - API_BEGIN(); - NDArray::CopyFromTo(from, to, stream); - API_END(); -} - -int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { - API_BEGIN(); - *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); - API_END(); -} - -int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { - API_BEGIN(); - *out = static_cast(TVMArrayHandleToObjectHandle(from))->ToDLPack(); - API_END(); -} - -int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { - API_BEGIN(); - ArrayCopyFromBytes(handle, data, nbytes); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyFromBytes") +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromBytes") .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { ArrayCopyFromBytes(arr, data, nbytes); }); -int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { - API_BEGIN(); - ArrayCopyToBytes(handle, data, nbytes); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyToBytes") +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyToBytes") .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { - ArrayCopyToBytes(arr, data, nbytes); + NDArray::CopyToBytes(arr, data, nbytes); }); -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyFromTo").set_body_typed([](DLTensor* from, DLTensor* to) { - NDArray::CopyFromTo(from, to); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromTo") + .set_body_typed([](DLTensor* from, DLTensor* to) { NDArray::CopyFromTo(from, to); }); diff --git a/src/runtime/object.cc b/src/runtime/object.cc deleted file mode 100644 index 095eee5f5e6b..000000000000 --- a/src/runtime/object.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/runtime/object.cc - * \brief Object type management system. - */ -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "object_internal.h" -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) { - return static_cast(ObjectPtrHash()(obj)); -}); - -} // namespace runtime -} // namespace tvm - -int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { - API_BEGIN(); - ICHECK(obj != nullptr); - out_tindex[0] = static_cast(obj)->type_index(); - API_END(); -} - -int TVMObjectRetain(TVMObjectHandle obj) { - API_BEGIN(); - tvm::runtime::ObjectInternal::ObjectRetain(obj); - API_END(); -} - -int TVMObjectFree(TVMObjectHandle obj) { - API_BEGIN(); - tvm::runtime::ObjectInternal::ObjectFree(obj); - API_END(); -} - -int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { - API_BEGIN(); - *is_derived = [&]() { - if (child_type_index == parent_type_index) return true; - if (child_type_index < parent_type_index) return false; - const TVMFFITypeInfo* child_type_info = TVMFFIGetTypeInfo(child_type_index); - const TVMFFITypeInfo* parent_type_info = TVMFFIGetTypeInfo(parent_type_index); - return (child_type_info->type_depth > parent_type_info->type_depth && - child_type_info->type_acenstors[parent_type_info->type_depth] == - static_cast(parent_type_index)); - }(); - API_END(); -} - -int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { - API_BEGIN(); - out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); - API_END(); -} - -int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { - API_BEGIN(); - auto key = tvm::runtime::Object::TypeIndex2Key(tindex); - *out_type_key = static_cast(malloc(key.size() + 1)); - strncpy(*out_type_key, key.c_str(), key.size() + 1); - API_END(); -} diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h deleted file mode 100644 index 40e4e2fb4855..000000000000 --- a/src/runtime/object_internal.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/runtime/object_internal.h - * \brief Expose a few functions for CFFI purposes. - * This file is not intended to be used - */ -#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ -#define TVM_RUNTIME_OBJECT_INTERNAL_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Internal object namespace to expose - * certain util functions for FFI. - */ -class ObjectInternal { - public: - /*! - * \brief Retain an object handle. - */ - static void ObjectRetain(TVMObjectHandle obj) { - if (obj != nullptr) { - // static_cast(obj)->IncRef(); - tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(obj); - } - } - - /*! - * \brief Free an object handle. - */ - static void ObjectFree(TVMObjectHandle obj) { - if (obj != nullptr) { - // static_cast(obj)->DecRef(); - tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(obj); - } - } - /*! - * \brief Check of obj derives from the type indicated by type index. - * \param obj The original object. - * \param type_index The type index of interest. - * \return The derivation checking result. - */ - // static bool DerivedFrom(const Object* obj, uint32_t type_index) { - // return obj->DerivedFrom(type_index); - // } - /*! - * \brief Expose TypeKey2Index - * \param type_key The original type key. - * \return the corresponding index. - */ - static uint32_t ObjectTypeKey2Index(const std::string& type_key) { - int32_t type_index; - TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - return static_cast(type_index); - } - /*! - * \brief Convert ModuleHandle to module node pointer. - * \param handle The module handle. - * \return the corresponding module node pointer. - */ - static ModuleNode* GetModuleNode(TVMModuleHandle handle) { - // NOTE: we will need to convert to Object - // then to ModuleNode in order to get the correct - // address translation - return static_cast(static_cast(handle)); - } -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 7288b65b0a24..91dad2af82c4 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ -#include +#include #include #include #include diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index d6eaec6c0e04..000e9a94599e 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -22,8 +22,8 @@ */ #include #include +#include #include -#include #include @@ -760,7 +760,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.opencl.alloc_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -788,7 +788,7 @@ TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") String("global.texture")); }); -TVM_REGISTER_GLOBAL("device_api.opencl.free_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.opencl.free_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -803,14 +803,15 @@ TVM_REGISTER_GLOBAL("device_api.opencl.free_nd") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.opencl").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.opencl") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global(); + *rv = static_cast(ptr); + }); TVM_REGISTER_OBJECT_TYPE(OpenCLTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.opencl").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.opencl").set_body_typed([](Device dev) { return Timer(make_object(dev)); }); @@ -893,7 +894,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { } }; -TVM_REGISTER_GLOBAL("DeviceAllocator.opencl") +TVM_FFI_REGISTER_GLOBAL("DeviceAllocator.opencl") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Allocator* alloc = new OpenCLPooledAllocator(); *rv = static_cast(alloc); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 90cdcb48bf96..8e8ee5a43b78 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -23,7 +23,7 @@ #include "opencl_module.h" #include -#include +#include #include #include @@ -146,7 +146,7 @@ ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, for (size_t i = 0; i < info.arg_types.size(); ++i) { DLDataType t = info.arg_types[i]; ICHECK_EQ(t.lanes, 1U); - if (t.code == kTVMOpaqueHandle) { + if (t.code == kDLOpaqueHandle) { // specially store pointer type size in OpenCL driver arg_size[i] = sizeof(void*); } else { @@ -389,10 +389,10 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 28e02a4e3749..7d281694decb 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index ec000524fa00..6fbbb05b7d58 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -31,7 +31,7 @@ #ifndef TVM_RUNTIME_PACK_ARGS_H_ #define TVM_RUNTIME_PACK_ARGS_H_ -#include +#include #include #include @@ -143,7 +143,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { } else if (t.code == kDLFloat) { if (t.bits == 64U) return FLOAT64_TO_FLOAT64; if (t.bits == 32U) return FLOAT64_TO_FLOAT32; - } else if (t.code == kTVMOpaqueHandle) { + } else if (t.code == kDLOpaqueHandle) { return HANDLE_TO_HANDLE; } LOG(FATAL) << "Cannot handle " << t << " as device function argument"; @@ -240,7 +240,6 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector pack_(num_args); int32_t* pack = reinterpret_cast(pack_.data()); int32_t* ptr = pack; - static_assert(sizeof(TVMValue) == 8, "invariant"); static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant"); const TVMFFIAny* raw_args = reinterpret_cast(args.data()); @@ -317,13 +316,13 @@ inline ffi::Function PackFuncVoidAddr(F f, const std::vector& arg_ty inline size_t NumBufferArgs(const std::vector& arg_types) { size_t base = arg_types.size(); for (size_t i = 0; i < arg_types.size(); ++i) { - if (arg_types[i].code != kTVMOpaqueHandle) { + if (arg_types[i].code != kDLOpaqueHandle) { base = i; break; } } for (size_t i = base; i < arg_types.size(); ++i) { - ICHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized"; + ICHECK(arg_types[i].code != kDLOpaqueHandle) << "Device function need to be organized"; } return base; } diff --git a/src/runtime/packed_func.cc b/src/runtime/packed_func.cc index 63ec7bbc7d47..38146227f9dd 100644 --- a/src/runtime/packed_func.cc +++ b/src/runtime/packed_func.cc @@ -20,8 +20,8 @@ * \file src/runtime/packed_func.cc * \brief Implementation of non-inlinable ffi::Function pieces. */ +#include #include -#include namespace tvm { namespace runtime { diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index bb67f6f6eb7b..2a12fba0b02d 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -43,11 +43,11 @@ namespace runtime { class DefaultTimerNode : public TimerNode { public: virtual void Start() { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + DeviceAPI::Get(device_)->StreamSync(device_, nullptr); start_ = std::chrono::high_resolution_clock::now(); } virtual void Stop() { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + DeviceAPI::Get(device_)->StreamSync(device_, nullptr); duration_ = std::chrono::high_resolution_clock::now() - start_; } virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } @@ -84,7 +84,7 @@ class CPUTimerNode : public TimerNode { }; TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { return Timer(make_object()); }); @@ -115,7 +115,7 @@ Timer Timer::Start(Device dev) { } } -TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); +TVM_FFI_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); namespace profiling { @@ -788,13 +788,15 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); -TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); -TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); -TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { + return n->AsCSV(); +}); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); }); -TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); -TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { return DeviceWrapper(dev); }); @@ -843,7 +845,7 @@ ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type }); } -TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") .set_body_typed)>([](Module mod, String func_name, int device_type, int device_id, @@ -924,26 +926,26 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re return ffi::Function::FromPacked(ftimer); } -TVM_REGISTER_GLOBAL("runtime.profiling.Report") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Report") .set_body_typed([](Array> calls, Map> device_metrics, Map configuration) { return Report(calls, device_metrics, configuration); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { return ObjectRef(make_object(count)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) { return ObjectRef(make_object(percent)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) { return ObjectRef(make_object(duration)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) { return ObjectRef(make_object(ratio)); }); diff --git a/src/runtime/regex.cc b/src/runtime/regex.cc index 8b4df9e69395..a91bf479ce4b 100644 --- a/src/runtime/regex.cc +++ b/src/runtime/regex.cc @@ -24,17 +24,18 @@ #include "./regex.h" -#include +#include namespace tvm { namespace runtime { bool regex_match(const std::string& match_against, const std::string& regex_pattern) { const auto regex_match_func = tvm::ffi::Function::GetGlobal("tvm.runtime.regex_match"); - CHECK(regex_match_func.has_value()) - << "RuntimeError: " - << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " - << "This can occur if the TVM Python library has not yet been imported."; + if (!regex_match_func.has_value()) { + TVM_FFI_THROW(RuntimeError) + << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " + << "This can occur if the TVM Python library has not yet been imported."; + } return (*regex_match_func)(regex_pattern, match_against).cast(); } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc deleted file mode 100644 index 443b625d1cb4..000000000000 --- a/src/runtime/registry.cc +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file registry.cc - * \brief The global registry of packed function. - */ -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Execution environment specific API registry. - * - * This registry stores C API function pointers about - * execution environment(e.g. python) specific API function that - * we need for specific low-level handling(e.g. signal checking). - * - * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the ffi::Function FFI when possible - * in other cases. - */ -class EnvCAPIRegistry { - public: - /*! - * \brief Callback to check if signals have been sent to the process and - * if so invoke the registered signal handler in the frontend environment. - * - * When running TVM in another language (Python), the signal handler - * may not be immediately executed, but instead the signal is marked - * in the interpreter state (to ensure non-blocking of the signal handler). - * - * \return 0 if no error happens, -1 if error happens. - */ - typedef int (*F_PyErr_CheckSignals)(); - - /*! \brief Callback to increment/decrement the python ref count */ - typedef void (*F_Py_IncDefRef)(void*); - - // NOTE: the following functions are only registered in a python - // environment. - /*! - * \brief PyErr_CheckSignal function - */ - F_PyErr_CheckSignals pyerr_check_signals = nullptr; - - /*! - * \brief Py_IncRef function - */ - F_Py_IncDefRef py_inc_ref = nullptr; - - /*! - * \brief Py_IncRef function - */ - F_Py_IncDefRef py_dec_ref = nullptr; - - /*! - \brief PyGILState_Ensure function - */ - void* (*py_gil_state_ensure)() = nullptr; - - /*! - \brief PyGILState_Release function - */ - void (*py_gil_state_release)(void*) = nullptr; - - static EnvCAPIRegistry* Global() { - static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); - return inst; - } - - // register environment(e.g. python) specific api functions - void Register(const String& symbol_name, void* fptr) { - if (symbol_name == "PyErr_CheckSignals") { - Update(symbol_name, &pyerr_check_signals, fptr); - } else if (symbol_name == "Py_IncRef") { - Update(symbol_name, &py_inc_ref, fptr); - } else if (symbol_name == "Py_DecRef") { - Update(symbol_name, &py_dec_ref, fptr); - } else if (symbol_name == "PyGILState_Ensure") { - Update(symbol_name, &py_gil_state_ensure, fptr); - } else if (symbol_name == "PyGILState_Release") { - Update(symbol_name, &py_gil_state_release, fptr); - } else { - LOG(FATAL) << "Unknown env API " << symbol_name; - } - } - - // implementation of tvm::runtime::EnvCheckSignals - void CheckSignals() { - // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr) { - // The C++ env comes without gil, so we need to grab gil here - WithGIL context(this); - if ((*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - throw EnvErrorAlreadySet(); - } - } - } - - void IncRef(void* python_obj) { - WithGIL context(this); - ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " - << "but Py_IncRef wasn't registered"; - (*py_inc_ref)(python_obj); - } - - void DecRef(void* python_obj) { - WithGIL context(this); - ICHECK(py_dec_ref) << "Attempted to call Py_DefRef through EnvCAPIRegistry, " - << "but Py_DefRef wasn't registered"; - (*py_dec_ref)(python_obj); - } - - private: - // update the internal API table - template - void Update(const String& symbol_name, FType* target, void* ptr) { - FType ptr_casted = reinterpret_cast(ptr); - if (target[0] != nullptr && target[0] != ptr_casted) { - LOG(WARNING) << "tvm.runtime.RegisterEnvCAPI overrides an existing function " << symbol_name; - } - target[0] = ptr_casted; - } - - struct WithGIL { - explicit WithGIL(EnvCAPIRegistry* self) : self(self) { - ICHECK(self->py_gil_state_ensure) << "Attempted to acquire GIL through EnvCAPIRegistry, " - << "but PyGILState_Ensure wasn't registered"; - ICHECK(self->py_gil_state_release) << "Attempted to acquire GIL through EnvCAPIRegistry, " - << "but PyGILState_Release wasn't registered"; - gil_state = self->py_gil_state_ensure(); - } - ~WithGIL() { - if (self && gil_state) { - self->py_gil_state_release(gil_state); - } - } - WithGIL(const WithGIL&) = delete; - WithGIL(WithGIL&&) = delete; - WithGIL& operator=(const WithGIL&) = delete; - WithGIL& operator=(WithGIL&&) = delete; - - EnvCAPIRegistry* self = nullptr; - void* gil_state = nullptr; - }; -}; - -void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); } - -WrappedPythonObject::WrappedPythonObject(void* python_obj) : python_obj_(python_obj) { - if (python_obj_) { - EnvCAPIRegistry::Global()->IncRef(python_obj_); - } -} - -WrappedPythonObject::~WrappedPythonObject() { - if (python_obj_) { - EnvCAPIRegistry::Global()->DecRef(python_obj_); - } -} - -WrappedPythonObject::WrappedPythonObject(WrappedPythonObject&& other) : python_obj_(nullptr) { - std::swap(python_obj_, other.python_obj_); -} -WrappedPythonObject& WrappedPythonObject::operator=(WrappedPythonObject&& other) { - std::swap(python_obj_, other.python_obj_); - return *this; -} - -WrappedPythonObject::WrappedPythonObject(const WrappedPythonObject& other) - : WrappedPythonObject(other.python_obj_) {} -WrappedPythonObject& WrappedPythonObject::operator=(const WrappedPythonObject& other) { - return *this = WrappedPythonObject(other); -} -WrappedPythonObject& WrappedPythonObject::operator=(std::nullptr_t) { - return *this = WrappedPythonObject(nullptr); -} - -} // namespace runtime -} // namespace tvm - -/*! \brief entry to easily hold returning information */ -struct TVMFuncThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; -}; - -/*! \brief Thread local store that can be used to hold return values. */ -typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; - -int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { - API_BEGIN(); - using tvm::runtime::GetRef; - tvm::ffi::Function::SetGlobal( - name, GetRef(static_cast(f)), override != 0); - API_END(); -} - -int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { - API_BEGIN(); - const auto fp = tvm::ffi::Function::GetGlobal(name); - if (fp.has_value()) { - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(tvm::ffi::Any(*fp)); - *out = val.v_obj; - } else { - *out = nullptr; - } - API_END(); -} - -int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { - API_BEGIN(); - TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); - ret->ret_vec_str = tvm::ffi::Function::ListGlobalNames(); - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); - API_END(); -} - -int TVMFuncRemoveGlobal(const char* name) { - API_BEGIN(); - tvm::ffi::Function::RemoveGlobal(name); - API_END(); -} - -int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) { - API_BEGIN(); - tvm::runtime::EnvCAPIRegistry::Global()->Register(name, ptr); - API_END(); -} diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 23aee70ff6c5..f62bb14608fa 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -29,13 +30,10 @@ #include #include #include -#include #include #include #include -#include "../runtime_base.h" - namespace tvm { namespace runtime { namespace relax_vm { @@ -66,7 +64,7 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); } -TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); /*! * \brief Builtin match R.Prim function. @@ -106,7 +104,7 @@ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t } } -TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); /*! * \brief Builtin match shape function. @@ -157,7 +155,7 @@ void MatchShape(ffi::PackedArgs args, Any* rv) { } } -TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body_packed(MatchShape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_shape").set_body_packed(MatchShape); /*! * \brief Builtin make prim value function. @@ -181,7 +179,7 @@ int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { } } -TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); /*! * \brief Builtin make shape function. @@ -212,7 +210,7 @@ void MakeShape(ffi::PackedArgs args, Any* rv) { *rv = ffi::Shape(std::move(shape)); } -TVM_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); /*! * \brief Builtin function to check if arg is Tensor(dtype, ndim) @@ -252,7 +250,7 @@ void CheckTensorInfo(ffi::PackedArgs args, Any* rv) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorInfo); /*! * \brief Builtin function to check if arg is Shape(ndim) @@ -272,7 +270,7 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); /*! * \brief Builtin function to check if arg is PrimValue(dtype) @@ -299,7 +297,7 @@ void CheckPrimValueInfo(AnyView arg, DataType dtype, Optional err_ctx) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); /*! * \brief Builtin function to check if arg is Tuple with size elements. @@ -317,7 +315,7 @@ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { << " but get a Tuple with " << ptr->size() << " elements."; } -TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); /*! * \brief Builtin function to check if arg is a callable function. @@ -331,7 +329,7 @@ void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { << arg->GetTypeKey(); } -TVM_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); //------------------------------------------------- // Storage management. @@ -356,61 +354,65 @@ Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_inde return Storage(buffer, alloc); } -TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); -TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); //------------------------------------------------- // Closure function handling, calling convention //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.make_closure").set_body_packed([](ffi::PackedArgs args, Any* rv) { - VMClosure clo = args[0].cast(); - std::vector saved_args; - saved_args.resize(args.size() - 1); - for (size_t i = 0; i < saved_args.size(); ++i) { - saved_args[i] = args[i + 1]; - } - auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); - *rv = VMClosure(clo->func_name, impl); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_closure") + .set_body_packed([](ffi::PackedArgs args, Any* rv) { + VMClosure clo = args[0].cast(); + std::vector saved_args; + saved_args.resize(args.size() - 1); + for (size_t i = 0; i < saved_args.size(); ++i) { + saved_args[i] = args[i + 1]; + } + auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); + *rv = VMClosure(clo->func_name, impl); + }); -TVM_REGISTER_GLOBAL("vm.builtin.invoke_closure").set_body_packed([](ffi::PackedArgs args, Any* rv) { - // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments - VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); - ObjectRef vm_closure = args[1].cast(); - vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_closure") + .set_body_packed([](ffi::PackedArgs args, Any* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef vm_closure = args[1].cast(); + vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); + }); -TVM_REGISTER_GLOBAL("vm.builtin.call_tir_dyn").set_body_packed([](ffi::PackedArgs args, Any* rv) { - ffi::Function func = args[0].cast(); - ffi::Shape to_unpack = args[args.size() - 1].cast(); - size_t num_tensor_args = args.size() - 2; +TVM_FFI_REGISTER_GLOBAL("vm.builtin.call_tir_dyn") + .set_body_packed([](ffi::PackedArgs args, Any* rv) { + ffi::Function func = args[0].cast(); + ffi::Shape to_unpack = args[args.size() - 1].cast(); + size_t num_tensor_args = args.size() - 2; - std::vector packed_args(num_tensor_args + to_unpack.size()); - std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); + std::vector packed_args(num_tensor_args + to_unpack.size()); + std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); - for (size_t i = 0; i < to_unpack.size(); ++i) { - packed_args[i + num_tensor_args] = to_unpack[i]; - } - func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); -}); + for (size_t i = 0; i < to_unpack.size(); ++i) { + packed_args[i + num_tensor_args] = to_unpack[i]; + } + func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); + }); //------------------------------------- // Builtin runtime operators. //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); -TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](Any a) -> Any { return a; }); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](Any a) -> Any { return a; }); -TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ffi::Shape new_shape) { - return data.CreateView(new_shape, data->dtype); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.reshape") + .set_body_typed([](NDArray data, ffi::Shape new_shape) { + return data.CreateView(new_shape, data->dtype); + }); -TVM_REGISTER_GLOBAL("vm.builtin.null_value").set_body_typed([]() -> std::nullptr_t { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.null_value").set_body_typed([]() -> std::nullptr_t { return nullptr; }); -TVM_REGISTER_GLOBAL("vm.builtin.to_device") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.to_device") .set_body_typed([](NDArray data, int dev_type, int dev_id) { Device dst_device = {(DLDeviceType)dev_type, dev_id}; return data.CopyTo(dst_device); @@ -459,13 +461,13 @@ bool ReadIfCond(AnyView cond) { return result != 0; } -TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); //------------------------------------- // Debugging API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") .set_body_packed([](ffi::PackedArgs args, Any* rv) -> void { ICHECK_GE(args.size(), 3); int num_args = args.size() - 3; @@ -491,16 +493,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") //------------------------------------- // Data structure API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](Array arr, int64_t index) { - return arr[index]; -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_getitem") + .set_body_typed([](Array arr, int64_t index) { return arr[index]; }); -TVM_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") .set_body_typed([](const ffi::ArrayObj* arr, int64_t index) { const_cast(arr)->SetItem(index, nullptr); }); -TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body_packed([](ffi::PackedArgs args, Any* rv) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body_packed([](ffi::PackedArgs args, Any* rv) { Array arr; for (int i = 0; i < args.size(); ++i) { arr.push_back(args[i]); @@ -508,7 +509,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body_packed([](ffi::PackedArgs *rv = arr; }); -TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { NDArray arr = data; if (data->device.device_type != kDLCPU) { arr = data.CopyTo(DLDevice{kDLCPU, 0}); @@ -542,7 +543,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data return ffi::Shape(out_shape); }); -TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { if (data->byte_offset == 0) { return data; } @@ -603,26 +604,26 @@ TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMF int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, int arg_offset) { using namespace tvm::runtime; - API_BEGIN(); + TVM_FFI_SAFE_CALL_BEGIN(); auto* list = static_cast(anylist); args[arg_offset] = list[index]; - API_END(); + TVM_FFI_SAFE_CALL_END(); } int TVMBackendAnyListResetItem(void* anylist, int index) { using namespace tvm::runtime; - API_BEGIN(); + TVM_FFI_SAFE_CALL_BEGIN(); auto* list = static_cast(anylist); list[index] = nullptr; - API_END(); + TVM_FFI_SAFE_CALL_END(); } int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, int ret_offset) { using namespace tvm::runtime; - API_BEGIN(); + TVM_FFI_SAFE_CALL_BEGIN(); auto* list = static_cast(anylist); list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(args[ret_offset])); - API_END(); + TVM_FFI_SAFE_CALL_END(); } } // extern "C" diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index 501f7925de2d..2fdf514b0e5d 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include "../../../support/utils.h" @@ -241,7 +241,7 @@ class CUDAGraphExtension : public VMExtension { } }; -TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() == 5 || args.size() == 4); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); @@ -256,7 +256,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr); }); -TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK_EQ(args.size(), 3); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index ff6f99c4d262..52a0588be35c 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -210,7 +210,7 @@ Module VMExecutable::LoadFromBinary(void* stream) { return Module(exec); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") .set_body_typed(VMExecutable::LoadFromBinary); Module VMExecutable::LoadFromFile(const String& file_name) { @@ -221,7 +221,7 @@ Module VMExecutable::LoadFromFile(const String& file_name) { return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") .set_body_typed(VMExecutable::LoadFromFile); void VMFuncInfo::Save(dmlc::Stream* strm) const { @@ -354,7 +354,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { this->constants.push_back(cell); } else { LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " - << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool."; + << ffi::TypeIndexToTypeKey(constant_type) << " when loading the VM constant pool."; } } } @@ -557,7 +557,7 @@ String VMExecutable::AsPython() const { return String(os.str()); } -TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); +TVM_FFI_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); } // namespace relax_vm } // namespace runtime diff --git a/src/runtime/relax_vm/hexagon/builtin.cc b/src/runtime/relax_vm/hexagon/builtin.cc index 3cfa4db71744..d2d05a0e8256 100644 --- a/src/runtime/relax_vm/hexagon/builtin.cc +++ b/src/runtime/relax_vm/hexagon/builtin.cc @@ -22,9 +22,9 @@ * \brief The hexagon graph related builtin functions for Relax virtual machine. */ +#include #include #include -#include #include #include "../../hexagon/hexagon_device_api.h" @@ -32,7 +32,7 @@ namespace tvm { namespace runtime { namespace relax_vm { -TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") .set_body_typed([](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, bool bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); @@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") CHECK(ret == DMA_SUCCESS); }); -TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") .set_body_typed([](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, [[maybe_unused]] NDArray src_arr, [[maybe_unused]] NDArray dst_arr) { ICHECK(inflight_dma >= 0); diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index f9689d64c647..12f52c0794e9 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -31,13 +31,15 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj); TVM_REGISTER_OBJECT_TYPE(RNNStateObj); // KV State base methods -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method(&KVStateObj::Clear); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence").set_body_method(&KVStateObj::AddSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method(&KVStateObj::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence") + .set_body_method(&KVStateObj::AddSequence); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") .set_body_method(&KVStateObj::RemoveSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence").set_body_method(&KVStateObj::ForkSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence") + .set_body_method(&KVStateObj::ForkSequence); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 3 || args.size() == 4) << "KVState BeginForward only accepts 3 or 4 arguments"; @@ -50,53 +52,53 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") } kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); }); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward").set_body_method(&KVStateObj::EndForward); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward").set_body_method(&KVStateObj::EndForward); // Attention KV Cache methods -TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") .set_body_method(&AttentionKVCacheObj::DisaggPrepareRecv); -TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") .set_body_method(&AttentionKVCacheObj::DisaggMarkSend); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") .set_body_method(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") .set_body_method(&AttentionKVCacheObj::Empty); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") .set_body_method(&AttentionKVCacheObj::GetNumAvailablePages); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") .set_body_method(&AttentionKVCacheObj::GetTotalSequenceLength); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") .set_body_method(&AttentionKVCacheObj::GetQueryPositions); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") .set_body_method(&AttentionKVCacheObj::DebugGetKV); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") .set_body_method(&AttentionKVCacheObj::DebugGetKVMLA); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray qkv_data, NDArray o_data) { kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), std::nullopt, std::move(o_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), std::move(o_data), std::move(lse_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, NDArray o_data, NDArray lse_data) { kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), std::move(lse_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); return kv_cache; }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") .set_body_typed([](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, NDArray o_cross_attn, NDArray lse_cross_attn) { return kv_cache->MergeAttnOutputInplace(std::move(o_self_attn), std::move(lse_self_attn), @@ -104,13 +106,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") }); // RNN State methods -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_set") .set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) { state->Set(layer_id, state_id, data); return state; }); -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get").set_body_method(&RNNStateObj::DebugGet); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get").set_body_method(&RNNStateObj::DebugGet); } // namespace relax_vm } // namespace runtime diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 3ddc9f6e0c9a..5800c4e2db93 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -19,12 +19,12 @@ #ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_ #define TVM_RUNTIME_RELAX_VM_KV_STATE_H_ #include +#include #include #include #include #include #include -#include namespace tvm { namespace runtime { diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index 45d8904d1932..8abeddcf18dc 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -259,7 +259,7 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj); //------------------------------------------------- // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") .set_body_typed(AttentionKVCacheLegacy::Create); AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDArray value) { @@ -267,14 +267,16 @@ AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDAr return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update").set_body_typed(AttentionKVCacheUpdate); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update") + .set_body_typed(AttentionKVCacheUpdate); AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, NDArray value) { cache->Append(value); return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append").set_body_typed(AttentionKVCacheAppend); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append") + .set_body_typed(AttentionKVCacheAppend); AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, NDArray value, int64_t max_cache_size) { @@ -282,7 +284,7 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cac return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override") .set_body_typed(AttentionKVCacheWindowOverride); AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache, @@ -293,14 +295,14 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheL return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks") .set_body_typed(AttentionKVCacheWindowOverrideWithSinks); NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) { return cache->View(shape); } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 1 || args.size() == 2) << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " @@ -325,7 +327,7 @@ void AttentionKVCacheArrayPopN(Array caches, int64_t n) } } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn") .set_body_typed(AttentionKVCacheArrayPopN); void AttentionKVCacheArrayClear(Array caches) { @@ -334,7 +336,7 @@ void AttentionKVCacheArrayClear(Array caches) { } } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear") .set_body_typed(AttentionKVCacheArrayClear); // NOTE this is a built-in highly related to LM so we put it here. @@ -399,7 +401,7 @@ int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, doubl return data[0].second; } -TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits); int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { ICHECK(prob.IsContiguous()); @@ -494,7 +496,7 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { return sampled_index; } -TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { ICHECK(prob.IsContiguous()); @@ -531,7 +533,8 @@ NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { return new_array; } -TVM_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform").set_body_typed(MultinomialFromUniform); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform") + .set_body_typed(MultinomialFromUniform); // This is an inplace operation. void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { @@ -554,7 +557,8 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty") + .set_body_typed(ApplyRepetitionPenalty); /*! * \brief Apply presence and frequency penalty. This is an inplace operation. @@ -589,7 +593,7 @@ void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty") .set_body_typed(ApplyPresenceAndFrequencyPenalty); // This is an inplace operation. @@ -614,7 +618,7 @@ void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature") .set_body_typed(ApplySoftmaxWithTemperature); } // namespace relax_vm diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index ef60f5b870e6..c69b48ccd697 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -39,8 +39,8 @@ #define __STDC_FORMAT_MACROS #endif #include +#include #include -#include #include #include @@ -162,7 +162,7 @@ void CopyNDArrayFromBytes(NDArray param, const void* data, size_t nbytes, NDArray staging_view = staging_buffer->value().CreateView(param.Shape(), param->dtype); staging_view.CopyFromBytes(data, nbytes); param.CopyFrom(staging_view); - TVMSynchronize(device.device_type, device.device_id, nullptr); + DeviceAPI::Get(device)->StreamSync(device, nullptr); } NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load( @@ -266,8 +266,8 @@ class NDArrayCache { Map pool_; }; -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") .set_body_packed([](ffi::PackedArgs args, Any* rv) { CHECK(args.size() == 2 || args.size() == 3); String name = args[0].cast(); @@ -285,14 +285,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") } arr = NDArray::Empty(shape, tensor->dtype, tensor->device); arr.CopyFrom(tensor); - TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); + DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr); } NDArrayCache::Update(name, arr, is_override); }); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load); // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. @@ -353,13 +353,15 @@ class ParamModuleNode : public runtime::ModuleNode { Array params_; }; -TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create); -TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache") + .set_body_typed(ParamModuleNode::Create); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") .set_body_typed(ParamModuleNode::CreateByName); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache") + .set_body_typed(ParamModuleNode::GetParams); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") .set_body_typed(ParamModuleNode::GetParamByName); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") .set_body_packed([](ffi::PackedArgs args, Any* rv) { Array names; names.reserve(args.size()); diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index b8444e03cf02..be9cd5955767 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -20,12 +20,12 @@ * \file src/runtime/relax_vm/paged_kv_cache.cc * \brief Runtime paged KV cache object for language models. */ +#include #include #include #include #include #include -#include #include #include @@ -2284,7 +2284,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body_packed([](ffi::PackedArgs args, Any* rv) { // Todo: cuda graph arg CHECK(args.size() == 28 || args.size() == 29) diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index 9f134dcace1d..d431fdb2ae2f 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -464,7 +464,7 @@ TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj); // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_create") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_create") .set_body_typed([](int64_t num_layers, // int64_t reserved_num_seqs, // int64_t max_history, // diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index dcaa1a3726a6..0fef2e4c6d07 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -162,14 +162,7 @@ struct VMFrame { std::vector register_file; /*! \brief Register in caller's frame to put return value */ RegName caller_return_register; - // The following fields are used for ffi::Function call within - // a single function scope. The space is reused across multiple - // packed func calls to increase cache locality and avoid re-allocation - /*! \brief Temporary argument value stack for packed func call. */ - std::vector call_arg_values; /*! \brief Temporary argument tcode stack for packed func call. */ - std::vector call_arg_tcodes; - std::vector call_args; VMFrame(Index pc, Index register_file_size) diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 67991717552e..a5bc3b1a0da5 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -24,10 +24,10 @@ #include #include #include +#include #include #include #include -#include #include "rocm_common.h" @@ -251,15 +251,16 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.rocm_host") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast(ptr); + }); class ROCMTimerNode : public TimerNode { public: @@ -292,11 +293,11 @@ class ROCMTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(ROCMTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) { return Timer(make_object()); }); -TVM_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() { return static_cast(ROCMThreadEntry::ThreadLocal()->stream); }); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 44c7483624e6..2d3ba16de247 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -23,7 +23,7 @@ #include "rocm_module.h" #include -#include +#include #include #include @@ -231,12 +231,12 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 710965d07824..ffe031fadfb4 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -20,9 +20,9 @@ /*! * \file rpc_device_api.cc */ +#include #include #include -#include #include @@ -150,7 +150,7 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { static RPCDeviceAPI inst; DeviceAPI* ptr = &inst; *rv = static_cast(ptr); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index cc7b1db8075f..9e54223a09df 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -23,10 +23,10 @@ */ #include "rpc_endpoint.h" -#include +#include +#include #include #include -#include #include #include @@ -41,7 +41,6 @@ #include "../../support/arena.h" #include "../../support/ring_buffer.h" #include "../../support/utils.h" -#include "../object_internal.h" #include "rpc_local_session.h" namespace tvm { diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 97d62cd586fc..c178db59a230 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -21,7 +21,7 @@ * \file rpc_event_impl.cc * \brief Event driven RPC server implementation. */ -#include +#include #include @@ -44,6 +44,6 @@ ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, }); } -TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); +TVM_FFI_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 38b52181d5d4..1769ed077f6a 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -23,8 +23,8 @@ */ #include "rpc_local_session.h" +#include #include -#include #include #include @@ -146,7 +146,7 @@ DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) { return DeviceAPI::Get(dev, allow_missing); } -TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index a0315604bff5..67faa3329be5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -21,10 +21,10 @@ * \file rpc_module.cc * \brief RPC runtime module. */ +#include #include #include #include -#include #include #include @@ -389,7 +389,7 @@ inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { } } -TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") +TVM_FFI_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, @@ -435,40 +435,40 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") } }); -TVM_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") +TVM_FFI_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); // server function registration. -TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) { - parent->Import(child); -}); +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") + .set_body_typed([](Module parent, Module child) { parent->Import(child); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") .set_body_typed([](Module parent, std::string name, bool query_imports) { return parent->GetFunction(name, query_imports); }); // functions to access an RPC module. -TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { +TVM_FFI_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { std::string tkey = sess->type_key(); ICHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }); -TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { +TVM_FFI_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { std::string tkey = parent->type_key(); ICHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }); -TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Module m = args[0].cast(); - std::string tkey = m->type_key(); - ICHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); -}); +TVM_FFI_REGISTER_GLOBAL("rpc.SessTableIndex") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Module m = args[0].cast(); + std::string tkey = m->type_key(); + ICHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); + }); -TVM_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle") .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, void* ndarray_handle) -> NDArray { return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index 25472de72777..b9121968137b 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -112,13 +112,14 @@ Module CreatePipeClient(std::vector cmd) { return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::vector cmd; - for (int i = 0; i < args.size(); ++i) { - cmd.push_back(args[i].cast()); - } - *rv = CreatePipeClient(cmd); -}); +TVM_FFI_REGISTER_GLOBAL("rpc.CreatePipeClient") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].cast()); + } + *rv = CreatePipeClient(cmd); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index 823fa232a953..eeb76c2b1512 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -21,7 +21,7 @@ * \file rpc_server_env.cc * \brief Server environment of the RPC. */ -#include +#include #include "../file_utils.h" @@ -35,14 +35,14 @@ std::string RPCGetPath(const std::string& name) { return (*f)(name).cast(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); auto data = args[1].cast(); SaveBinaryToFile(file_name, data); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.download") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.download") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); std::string data; @@ -51,7 +51,7 @@ TVM_REGISTER_GLOBAL("tvm.rpc.server.download") *rv = ffi::Bytes(data); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.remove") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.remove") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); RemoveFile(file_name); diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index f51117211a82..2564242bdf0f 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,7 +21,7 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ -#include +#include #include @@ -121,7 +121,7 @@ void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) { ->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); @@ -129,7 +129,7 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi: *rv = RPCClientConnect(url, port, key, enable_logging, args.Slice(4)); }); -TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_int = args[0].as()) { RPCServerLoop(opt_int.value()); } else { @@ -162,7 +162,7 @@ class SimpleSockHandler : public dmlc::Stream { support::TCPSocket sock_; }; -TVM_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) { +TVM_FFI_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) { auto handler = SimpleSockHandler(sockfd); RPCReference::ReturnException(msg.c_str(), &handler); return; diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h deleted file mode 100644 index 3037c8d84ff0..000000000000 --- a/src/runtime/runtime_base.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file runtime_base.h - * \brief Base of all C APIs - */ -#ifndef TVM_RUNTIME_RUNTIME_BASE_H_ -#define TVM_RUNTIME_RUNTIME_BASE_H_ - -#include - -#include - -/*! \brief macro to guard beginning and end section of all functions */ -#define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); - and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() \ - } \ - catch (::tvm::runtime::EnvErrorAlreadySet & _except_) { \ - return -2; \ - } \ - catch (std::exception & _except_) { \ - return TVMAPIHandleException(_except_); \ - } \ - return 0; // NOLINT(*) -/*! - * \brief every function starts with API_BEGIN(); - * and finishes with API_END() or API_END_HANDLE_ERROR - * The finally clause contains procedure to cleanup states when an error happens. - */ -#define API_END_HANDLE_ERROR(Finalize) \ - } \ - catch (::tvm::runtime::EnvErrorAlreadySet & _except_) { \ - return -2; \ - } \ - catch (std::exception & _except_) { \ - Finalize; \ - return TVMAPIHandleException(_except_); \ - } \ - return 0; // NOLINT(*) - -/*! - * \brief handle exception throwed out - * \param e the exception - * \return the return value of API after exception is handled - */ -int TVMAPIHandleException(const std::exception& e); - -#endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/spirv/spirv_shader.h index 293dc5b78638..d194f70629e6 100644 --- a/src/runtime/spirv/spirv_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ #define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ -#include +#include #include #include #include diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 2e469d6e956a..3eae0cb73940 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -24,10 +24,10 @@ */ #include "./static_library.h" +#include #include #include #include -#include #include @@ -127,8 +127,8 @@ Module LoadStaticLibrary(const std::string& filename, Array func_names) return Module(node); } -TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") .set_body_typed(StaticLibraryNode::LoadFromBinary); } // namespace runtime diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index d3693c82564d..46c08e4afd9a 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -21,9 +21,9 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ +#include #include #include -#include #include @@ -112,13 +112,14 @@ class SystemLibModuleRegistry { std::unordered_map lib_map_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string symbol_prefix = ""; - if (args.size() != 0) { - symbol_prefix = args[0].cast(); - } - *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.SystemLib") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::string symbol_prefix = ""; + if (args.size() != 0) { + symbol_prefix = args[0].cast(); + } + *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 1e02d6a66154..e3a3a0a8fc63 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -23,11 +23,11 @@ */ #include #include +#include +#include #include -#include #include #include -#include #include #if TVM_THREADPOOL_USE_OPENMP #include @@ -379,7 +379,7 @@ class ThreadPool { * \brief args[0] is the AffinityMode, args[1] is the number of threads. * args2 is a list of CPUs which is used to set the CPU affinity. */ -TVM_REGISTER_GLOBAL("runtime.config_threadpool") +TVM_FFI_REGISTER_GLOBAL("runtime.config_threadpool") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { threading::ThreadGroup::AffinityMode mode = static_cast(args[0].cast()); @@ -395,7 +395,7 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool") threading::Configure(mode, nthreads, cpus); }); -TVM_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { +TVM_FFI_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { return threading::NumThreads(); }); diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 0bea0dac1b70..ef835f20d171 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -21,8 +21,9 @@ * \file threading_backend.cc * \brief Native threading backend */ +#include +#include #include -#include #include #if defined(__linux__) || defined(__ANDROID__) @@ -436,7 +437,7 @@ int MaxConcurrency() { // This global function can be used by disco runtime to bind processes // to CPUs. -TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") +TVM_FFI_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") .set_body_typed([](ffi::Shape cpu_ids) { SetThreadAffinity(CURRENT_THREAD_HANDLE, std::vector{cpu_ids.begin(), cpu_ids.end()}); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index f1e0ef587ecc..c1961d8065ae 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -20,11 +20,11 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ #define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ -#include +#include +#include #include #include #include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index fcb3e764bf86..12181f8c159d 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -455,12 +455,13 @@ VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { return const_cast(const_cast(this)->device(device_id)); } -TVM_REGISTER_GLOBAL("device_api.vulkan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.vulkan") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = VulkanDeviceAPI::Global(); + *rv = static_cast(ptr); + }); -TVM_REGISTER_GLOBAL("device_api.vulkan.get_target_property") +TVM_FFI_REGISTER_GLOBAL("device_api.vulkan.get_target_property") .set_body_typed([](Device dev, const std::string& property) { ffi::Any rv; VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 600d7d6f870c..063dc5bde009 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -20,7 +20,7 @@ #include "vulkan_module.h" #include -#include +#include #include "../file_utils.h" #include "vulkan_wrapped_func.h" @@ -64,9 +64,9 @@ Module VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index ab212c2eade4..f4922a1bf01d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -289,7 +289,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, auto fit = fmap_.find(func_name); ICHECK(fit != fmap_.end()); for (DLDataType arg_type : fit->second.arg_types) { - if (arg_type.code == kTVMOpaqueHandle) { + if (arg_type.code == kDLOpaqueHandle) { push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); ++num_buffer; } else { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 2edc71cdb1b3..13f272d7c946 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include namespace tvm { @@ -101,20 +101,24 @@ void Namer::Name(ObjectRef node, String name) { TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode); TVM_REGISTER_NODE_TYPE(IRBuilderNode); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") .set_body_method(&IRBuilderFrameNode::EnterWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") .set_body_method(&IRBuilderFrameNode::ExitWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") .set_body_method(&IRBuilderFrameNode::AddCallback); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter") + .set_body_method(&IRBuilder::EnterWithScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit") + .set_body_method(&IRBuilder::ExitWithScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope") + .set_body_typed(IRBuilder::IsInScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderName") + .set_body_typed(IRBuilder::Name); } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 2b02a80e3eaf..6cb61147a96a 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include namespace tvm { diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index e5f8759aec8e..270f4623ef0c 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include -#include #include #include #include @@ -165,14 +165,14 @@ VDevice LookupVDevice(String target_kind, int device_index) { return VDevice(); } -TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index 74caa95e9012..fcf9e0eb2c5b 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -54,7 +54,7 @@ Expr MakeCallTIRDist(Expr func, Tuple args, Arrayoutput = std::move(normalized_value); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo") + .set_body_typed(FuncRetStructInfo); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); ///////////////////////////// BindingBlock ////////////////////////////// @@ -191,9 +192,9 @@ void DataflowBlockOutput(const Array& vars) { } } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") .set_body_typed(DataflowBlockOutput); /////////////////////////////// Bindings /////////////////////////////// @@ -236,9 +237,9 @@ tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { return binding->var; } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); /////////////////////////////// SeqExpr /////////////////////////////// @@ -247,7 +248,7 @@ SeqExprFrame SeqExpr() { return SeqExprFrame(n); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); ///////////////////////////// If Then Else ///////////////////////////// @@ -269,9 +270,9 @@ ElseFrame Else() { return ElseFrame(n); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); } // namespace relax } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index b4f94b2b893d..da772f608579 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -657,9 +657,9 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; if (auto var = obj.as()) { @@ -671,45 +671,45 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); throw; }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") .set_body_typed([](ffi::Variant thread_tag_or_var, PrimExpr extent) { if (auto var = thread_tag_or_var.as()) { return LaunchThread(var.value(), extent); @@ -721,60 +721,60 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") throw; } }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); #define TVM_TMP_STR(x) #x -#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); - -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); - -#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); - -#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); - -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.min") +#define TVM_FFI_REGISTER_GLOBAL_SIZE(Prefix, DType) \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); + +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); + +#define TVM_FFI_REGISTER_GLOBAL_LANES(Prefix, Func) \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); + +#define TVM_FFI_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); + +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min") .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.max") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.max") .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); } // namespace tir } // namespace ir_builder diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 89f517fc44ea..8f1fd77d782d 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -17,8 +17,8 @@ * under the License. */ #include +#include #include -#include #include namespace tvm { @@ -234,49 +234,50 @@ DocStringDoc::DocStringDoc(String docs) { } TVM_REGISTER_NODE_TYPE(DocNode); -TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") +TVM_FFI_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") .set_body_typed([](Doc doc, Array source_paths) { doc->source_paths = source_paths; }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocAttr") .set_body_method(&ExprDocNode::Attr); -TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::operator[]); -TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::operator[]); +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocCall") .set_body_method, Array, Array>( &ExprDocNode::Call); TVM_REGISTER_NODE_TYPE(StmtDocNode); -TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment") +TVM_FFI_REGISTER_GLOBAL("script.printer.StmtDocSetComment") .set_body_typed([](StmtDoc doc, Optional comment) { doc->comment = comment; }); TVM_REGISTER_NODE_TYPE(StmtBlockDocNode); -TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { +TVM_FFI_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { return StmtBlockDoc(stmts); }); TVM_REGISTER_NODE_TYPE(LiteralDocNode); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); -TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); +TVM_FFI_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { + return IdDoc(name); +}); TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); -TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc").set_body_typed([](ExprDoc value, String attr) { - return AttrAccessDoc(value, attr); -}); +TVM_FFI_REGISTER_GLOBAL("script.printer.AttrAccessDoc") + .set_body_typed([](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); }); TVM_REGISTER_NODE_TYPE(IndexDocNode); -TVM_REGISTER_GLOBAL("script.printer.IndexDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.IndexDoc") .set_body_typed([](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); TVM_REGISTER_NODE_TYPE(CallDocNode); -TVM_REGISTER_GLOBAL("script.printer.CallDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.CallDoc") .set_body_typed([](ExprDoc callee, // Array args, // Array kwargs_keys, // @@ -285,104 +286,103 @@ TVM_REGISTER_GLOBAL("script.printer.CallDoc") }); TVM_REGISTER_NODE_TYPE(OperationDocNode); -TVM_REGISTER_GLOBAL("script.printer.OperationDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.OperationDoc") .set_body_typed([](int32_t kind, Array operands) { return OperationDoc(OperationDocNode::Kind(kind), operands); }); TVM_REGISTER_NODE_TYPE(LambdaDocNode); -TVM_REGISTER_GLOBAL("script.printer.LambdaDoc").set_body_typed([](Array args, ExprDoc body) { - return LambdaDoc(args, body); -}); +TVM_FFI_REGISTER_GLOBAL("script.printer.LambdaDoc") + .set_body_typed([](Array args, ExprDoc body) { return LambdaDoc(args, body); }); TVM_REGISTER_NODE_TYPE(TupleDocNode); -TVM_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { +TVM_FFI_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { return TupleDoc(elements); }); TVM_REGISTER_NODE_TYPE(ListDocNode); -TVM_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { return ListDoc(elements); }); TVM_REGISTER_NODE_TYPE(DictDocNode); -TVM_REGISTER_GLOBAL("script.printer.DictDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.DictDoc") .set_body_typed([](Array keys, Array values) { return DictDoc(keys, values); }); TVM_REGISTER_NODE_TYPE(SliceDocNode); -TVM_REGISTER_GLOBAL("script.printer.SliceDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.SliceDoc") .set_body_typed([](Optional start, Optional stop, Optional step) { return SliceDoc(start, stop, step); }); TVM_REGISTER_NODE_TYPE(AssignDocNode); -TVM_REGISTER_GLOBAL("script.printer.AssignDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.AssignDoc") .set_body_typed([](ExprDoc lhs, Optional rhs, Optional annotation) { return AssignDoc(lhs, rhs, annotation); }); TVM_REGISTER_NODE_TYPE(IfDocNode); -TVM_REGISTER_GLOBAL("script.printer.IfDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.IfDoc") .set_body_typed([](ExprDoc predicate, Array then_branch, Array else_branch) { return IfDoc(predicate, then_branch, else_branch); }); TVM_REGISTER_NODE_TYPE(WhileDocNode); -TVM_REGISTER_GLOBAL("script.printer.WhileDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.WhileDoc") .set_body_typed([](ExprDoc predicate, Array body) { return WhileDoc(predicate, body); }); TVM_REGISTER_NODE_TYPE(ForDocNode); -TVM_REGISTER_GLOBAL("script.printer.ForDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ForDoc") .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array body) { return ForDoc(lhs, rhs, body); }); TVM_REGISTER_NODE_TYPE(ScopeDocNode); -TVM_REGISTER_GLOBAL("script.printer.ScopeDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ScopeDoc") .set_body_typed([](Optional lhs, ExprDoc rhs, Array body) { return ScopeDoc(lhs, rhs, body); }); TVM_REGISTER_NODE_TYPE(ExprStmtDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { return ExprStmtDoc(expr); }); TVM_REGISTER_NODE_TYPE(AssertDocNode); -TVM_REGISTER_GLOBAL("script.printer.AssertDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.AssertDoc") .set_body_typed([](ExprDoc test, Optional msg = std::nullopt) { return AssertDoc(test, msg); }); TVM_REGISTER_NODE_TYPE(ReturnDocNode); -TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { return ReturnDoc(value); }); TVM_REGISTER_NODE_TYPE(FunctionDocNode); -TVM_REGISTER_GLOBAL("script.printer.FunctionDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.FunctionDoc") .set_body_typed([](IdDoc name, Array args, Array decorators, Optional return_type, Array body) { return FunctionDoc(name, args, decorators, return_type, body); }); TVM_REGISTER_NODE_TYPE(ClassDocNode); -TVM_REGISTER_GLOBAL("script.printer.ClassDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ClassDoc") .set_body_typed([](IdDoc name, Array decorators, Array body) { return ClassDoc(name, decorators, body); }); TVM_REGISTER_NODE_TYPE(CommentDocNode); -TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { +TVM_FFI_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { return CommentDoc(comment); }); TVM_REGISTER_NODE_TYPE(DocStringDocNode); -TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { +TVM_FFI_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { return DocStringDoc(docs); }); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 2bb8e2a1dc51..85b5b755d253 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include #include @@ -727,7 +727,7 @@ String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { return result.substr(0, last_space); } -TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); +TVM_FFI_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); } // namespace printer } // namespace script diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 15762705b418..8c72eb4ef318 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include #include diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index 9b26a942be82..3d7abe821745 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -82,7 +82,7 @@ TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::TensorTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); -TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); +TVM_FFI_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); } // namespace printer } // namespace script diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index f35c00b4fb22..d0d9a35db83e 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -22,10 +22,10 @@ * \file ffi_testing.cc */ #include +#include #include #include #include -#include #include #include @@ -53,13 +53,22 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback") +TVM_FFI_REGISTER_GLOBAL("testing.GetShapeSize").set_body_typed([](ffi::Shape shape) { + return static_cast(shape.size()); +}); + +TVM_FFI_REGISTER_GLOBAL("testing.GetShapeElem").set_body_typed([](ffi::Shape shape, int idx) { + ICHECK_LT(idx, shape.size()); + return shape[idx]; +}); + +TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); *ret = ffi::TypedFunction([pf]() { pf(); }); }); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") +TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); auto result = ffi::TypedFunction([pf]() { @@ -71,22 +80,23 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") *ret = result; }); -TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") +TVM_FFI_REGISTER_GLOBAL("testing.test_check_eq_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto msg = args[0].cast(); *ret = ffi::TypedFunction([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); }); -TVM_REGISTER_GLOBAL("testing.device_test").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto dev = args[0].cast(); - int dtype = args[1].cast(); - int did = args[2].cast(); - CHECK_EQ(static_cast(dev.device_type), dtype); - CHECK_EQ(static_cast(dev.device_id), did); - *ret = dev; -}); +TVM_FFI_REGISTER_GLOBAL("testing.device_test") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto dev = args[0].cast(); + int dtype = args[1].cast(); + int did = args[2].cast(); + CHECK_EQ(static_cast(dev.device_type), dtype); + CHECK_EQ(static_cast(dev.device_id), did); + *ret = dev; + }); -TVM_REGISTER_GLOBAL("testing.identity_cpp") +TVM_FFI_REGISTER_GLOBAL("testing.identity_cpp") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); ICHECK(identity_func.has_value()) @@ -105,7 +115,7 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); +TVM_FFI_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); class FrontendTestModuleNode : public runtime::ModuleNode { public: @@ -145,22 +155,23 @@ runtime::Module NewFrontendTestModule() { return runtime::Module(n); } -TVM_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule); +TVM_FFI_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule); -TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { +TVM_FFI_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::chrono::duration duration(static_cast(timeout * 1e9)); std::this_thread::sleep_for(duration); }); -TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { - if (x % 2 == 0) { - return IntImm(DataType::Int(64), x / 2); - } else { - return String("argument was odd"); - } -}); +TVM_FFI_REGISTER_GLOBAL("testing.ReturnsVariant") + .set_body_typed([](int x) -> Variant { + if (x % 2 == 0) { + return IntImm(DataType::Int(64), x / 2); + } else { + return String("argument was odd"); + } + }); -TVM_REGISTER_GLOBAL("testing.AcceptsVariant") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { if (auto opt_str = arg.as()) { return opt_str.value()->GetTypeKey(); @@ -169,25 +180,25 @@ TVM_REGISTER_GLOBAL("testing.AcceptsVariant") } }); -TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); -TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); -TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array arg) -> Any { +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array arg) -> Any { return arg[0]; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") .set_body_typed([](Map map, Any key) -> Any { return map[key]; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") .set_body_typed([](Map map) -> ObjectRef { return map; }); -TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { return expr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") .set_body_typed([](Array arr) -> ObjectRef { for (ObjectRef item : arr) { CHECK(item->IsInstance()) @@ -196,7 +207,7 @@ TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") return arr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") .set_body_typed([](Array> arr) -> ObjectRef { for (auto item : arr) { CHECK(item.as() || item.as()) @@ -205,7 +216,7 @@ TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") return arr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") .set_body_typed([](Map map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; @@ -254,19 +265,21 @@ class TestingEventLogger { std::vector entries_; }; -TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() != 0 && args[0].try_cast()) { - TestingEventLogger::ThreadLocal()->Record(args[0].cast()); - } else { - TestingEventLogger::ThreadLocal()->Record("X"); - } -}); +TVM_FFI_REGISTER_GLOBAL("testing.record_event") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + if (args.size() != 0 && args[0].try_cast()) { + TestingEventLogger::ThreadLocal()->Record(args[0].cast()); + } else { + TestingEventLogger::ThreadLocal()->Record("X"); + } + }); -TVM_REGISTER_GLOBAL("testing.reset_events").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - TestingEventLogger::ThreadLocal()->Reset(); -}); +TVM_FFI_REGISTER_GLOBAL("testing.reset_events") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + TestingEventLogger::ThreadLocal()->Reset(); + }); -TVM_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { TestingEventLogger::ThreadLocal()->Dump(); }); } // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index e78cfdd016f1..01b49bb92e79 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -17,9 +17,9 @@ * under the License. */ #include +#include #include #include -#include #ifndef TVM_INFO_GIT_COMMIT_HASH #define TVM_INFO_GIT_COMMIT_HASH "NOT-FOUND" @@ -367,6 +367,6 @@ TVM_DLL ffi::Map GetLibInfo() { return result; } -TVM_REGISTER_GLOBAL("support.GetLibInfo").set_body_typed(GetLibInfo); +TVM_FFI_REGISTER_GLOBAL("support.GetLibInfo").set_body_typed(GetLibInfo); } // namespace tvm diff --git a/src/support/socket.h b/src/support/socket.h index e3972488d4b8..e9e2f87f9dbf 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -47,8 +47,9 @@ #include #include #endif + +#include #include -#include #include #include diff --git a/src/target/build_common.h b/src/target/build_common.h index 7c9ad8cb3c68..70f15d091ed2 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -24,8 +24,8 @@ #ifndef TVM_TARGET_BUILD_COMMON_H_ #define TVM_TARGET_BUILD_COMMON_H_ +#include #include -#include #include #include #include diff --git a/src/target/codegen.cc b/src/target/codegen.cc index a6b5b4b041c8..8ddc071cba0f 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -22,10 +22,10 @@ * \brief Common utilities to generated C style code. */ #include +#include #include -#include +#include #include -#include #include #include #include @@ -361,14 +361,14 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, .cast(); } -TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); +TVM_FFI_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export a few auxiliary function to the runtime namespace. -TVM_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { return runtime::symbol::tvm_dev_mblob; }); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") .set_body_typed([](const runtime::Module& mod) { std::string buffer = PackImportsToBytes(mod); ffi::Shape::index_type size = buffer.size(); @@ -384,8 +384,8 @@ TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") return array; }); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/myfloat/myfloat.cc b/src/target/datatype/myfloat/myfloat.cc index c0c2fffa03da..afee8a7c4bf0 100644 --- a/src/target/datatype/myfloat/myfloat.cc +++ b/src/target/datatype/myfloat/myfloat.cc @@ -26,7 +26,7 @@ * * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? */ -#include +#include #include #include diff --git a/src/target/datatype/posit/posit-wrapper.cc b/src/target/datatype/posit/posit-wrapper.cc index 700c5cb9dbe9..bb2af37ec921 100644 --- a/src/target/datatype/posit/posit-wrapper.cc +++ b/src/target/datatype/posit/posit-wrapper.cc @@ -28,7 +28,7 @@ * * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? */ -#include +#include #include diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 79065d0024c5..2c1fc84d4084 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -18,7 +18,7 @@ */ #include "registry.h" -#include +#include namespace tvm { namespace datatype { @@ -26,23 +26,23 @@ namespace datatype { using ffi::Any; using ffi::PackedArgs; -TVM_REGISTER_GLOBAL("dtype.register_custom_type") +TVM_FFI_REGISTER_GLOBAL("dtype.register_custom_type") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { datatype::Registry::Global()->Register(args[0].cast(), static_cast(args[1].cast())); }); -TVM_REGISTER_GLOBAL("dtype.get_custom_type_code") +TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_code") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); }); -TVM_REGISTER_GLOBAL("dtype.get_custom_type_name") +TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_name") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeName(args[0].cast()); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") +TVM_FFI_REGISTER_GLOBAL("runtime._datatype_get_type_registered") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); }); diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index 46b189880f64..eba7739a6b86 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -20,8 +20,8 @@ #ifndef TVM_TARGET_DATATYPE_REGISTRY_H_ #define TVM_TARGET_DATATYPE_REGISTRY_H_ +#include #include -#include #include #include @@ -38,7 +38,7 @@ namespace datatype { * directly---see the TVM globals registered in the corresponding .cc file. * Currently, user should manually choose a type name and a type code, * ensuring that neither conflict with existing types. - * 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to + * 2. Use TVM_FFI_REGISTER_GLOBAL to register the lowering functions needed to * lower the custom datatype. In general, these will look like: * For Casts: tvm.datatype.lower..Cast.. * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index ea8ccd98b1af..ac45476f7702 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_INTRIN_RULE_H_ #define TVM_TARGET_INTRIN_RULE_H_ -#include +#include #include #include diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index b690c0fc28b1..9d968cdb6478 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -25,7 +25,7 @@ #include #include -#include +#include #include "../../arith/scalable_expression.h" #include "codegen_cpu.h" @@ -106,7 +106,7 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { this->VisitStmt(op->body); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAArch64()); }); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 35b5d1378423..048c4160b118 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -47,9 +47,9 @@ #endif #include #include -#include +#include +#include #include -#include #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" @@ -356,9 +356,9 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); +TVM_FFI_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAMDGPU()); }); diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 3abebec2a36e..03ef982d1308 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -24,7 +24,7 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -132,7 +132,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenARM()); }); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 0a5223ae029b..bfbd65e524fb 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -49,7 +49,7 @@ #include #include #include -#include +#include #include #include @@ -75,12 +75,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, CodeGenLLVM::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); system_lib_prefix_ = system_lib_prefix; dbg_info_ = CreateDebugInfo(module_.get()); - static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); export_system_symbols_.clear(); // Runtime types. - t_tvm_shape_index_ = llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: @@ -89,7 +87,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: // typedef struct { uint8_t code; uint8_t bits; uint16_t lanes; } DLDataType; t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); - // Defined in include/tvm/runtime/c_runtime_api.h: + // Defined in include/tvm/runtime/base.h: // typedef void* TVMFunctionHandle; t_tvm_func_handle_ = t_void_p_; // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: @@ -1158,7 +1156,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenCPU()); }); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 68bf973439f3..baf7497bc0d1 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -589,9 +589,9 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } -TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); +TVM_FFI_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenHexagon()); }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f77e32efd587..634c9c2b57a5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -90,7 +90,7 @@ #include #include #include -#include +#include #include #include @@ -2325,19 +2325,18 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) return nullptr; } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetDefaultTargetTriple").set_body_typed([]() -> std::string { - return llvm::sys::getDefaultTargetTriple(); -}); +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetDefaultTargetTriple") + .set_body_typed([]() -> std::string { return llvm::sys::getDefaultTargetTriple(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetProcessTriple").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetProcessTriple").set_body_typed([]() -> std::string { return llvm::sys::getProcessTriple(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> std::string { return llvm::sys::getHostCPUName().str(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") .set_body_typed([]() -> Map { #if TVM_LLVM_VERSION >= 190 Map ret; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 865b383cd334..a0ffb5a1ce10 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -368,9 +368,9 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); +TVM_FFI_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenNVPTX()); }); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 954d4e7efd56..435b453d49ba 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -30,7 +30,7 @@ #include #endif #include -#include +#include #include #include @@ -132,7 +132,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenX86_64()); }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index a0e040a2048e..4b64e92127d3 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -26,7 +26,7 @@ #ifdef TVM_LLVM_VERSION -#include +#include #include #include #include diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index e7be40fb9041..48fc64172215 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -22,7 +22,7 @@ */ #ifdef TVM_LLVM_VERSION -#include +#include #include #include #include diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index c80d8388da9c..30afcee92acc 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -23,7 +23,7 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #include #include #include diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index e8bfc9f19e66..e5b2fc47ec8d 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -56,13 +56,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -621,14 +621,14 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, return nullptr; } -TVM_REGISTER_GLOBAL("target.build.llvm") +TVM_FFI_REGISTER_GLOBAL("target.build.llvm") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { auto n = make_object(); n->Init(mod, target); return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") +TVM_FFI_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); @@ -643,7 +643,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") +TVM_FFI_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") .set_body_typed([](std::string name) -> int64_t { #if TVM_LLVM_VERSION >= 200 return static_cast(llvm::Intrinsic::lookupIntrinsicID(name)); @@ -652,7 +652,7 @@ TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") #endif }); -TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String { #if TVM_LLVM_VERSION >= 130 return std::string(llvm::Intrinsic::getBaseName(static_cast(id))); #elif TVM_LLVM_VERSION >= 40 @@ -667,7 +667,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> String { #if TVM_LLVM_VERSION >= 120 #if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) using namespace llvm::sys::detail::x86; @@ -683,34 +683,35 @@ TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> return "unimplemented"; }); -TVM_REGISTER_GLOBAL("target.llvm_get_vector_width").set_body_typed([](const Target& target) -> int { - auto use_target = target.defined() ? target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return -1; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - return llvm_backend.GetVectorWidth(); -}); +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_vector_width") + .set_body_typed([](const Target& target) -> int { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return -1; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetVectorWidth(); + }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String { return llvm::sys::getDefaultTargetTriple(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_cpu").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_cpu").set_body_typed([]() -> String { return llvm::sys::getHostCPUName().str(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { auto llvm_instance = std::make_unique(); LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); return llvm_backend.GetAllLLVMTargets(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") .set_body_typed([](const Target& target) -> Array { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -724,7 +725,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") return llvm_backend.GetAllLLVMTargetArches(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_features") .set_body_typed([](const Target& target) -> Map { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -738,7 +739,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") return llvm_backend.GetAllLLVMCpuFeatures(); }); -TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") +TVM_FFI_REGISTER_GLOBAL("target.llvm_cpu_has_feature") .set_body_typed([](const String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -754,7 +755,7 @@ TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") return has_feature; }); -TVM_REGISTER_GLOBAL("target.target_has_feature") +TVM_FFI_REGISTER_GLOBAL("target.target_has_feature") .set_body_typed([](const String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -768,11 +769,11 @@ TVM_REGISTER_GLOBAL("target.target_has_feature") return llvm_target.TargetHasCPUFeature(feature); }); -TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { +TVM_FFI_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { return TVM_LLVM_VERSION / 10; }); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ll") .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { auto n = make_object(); n->SetJITEngine("orcjit"); @@ -780,7 +781,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") +TVM_FFI_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { LLVMInstance llvm_instance; auto* tm = With(llvm_instance, target_str) @@ -788,7 +789,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") return tm != nullptr; }); -TVM_REGISTER_GLOBAL("codegen.codegen_blob") +TVM_FFI_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string, std::string c_symbol_prefix) -> runtime::Module { auto n = make_object(); diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 58f80858d087..068f6c2f7196 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -169,7 +169,7 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); +TVM_FFI_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", String); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ad9456a1e9a5..344c0857c4d4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -338,22 +338,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ")"; return os.str(); } else { - ICHECK_LT(kind, builtin::kTVMValueKindBound_); - std::ostringstream os; - os << "(((TVMValue*)"; - this->PrintExpr(buffer, os); - os << ")[" << index << "]."; - if (t.is_handle()) { - os << "v_handle"; - } else if (t.is_float()) { - os << "v_float64"; - } else if (t.is_int()) { - os << "v_int64"; - } else { - LOG(FATAL) << "Do not know how to handle type" << t; - } - os << ")"; - return os.str(); + TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; } } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index ef86f09ca28e..ad73fc9079e9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -43,7 +43,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; decl_stream << "#define TVM_EXPORTS\n"; - decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + decl_stream << "#include \"tvm/runtime/base.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; decl_stream << "#include \n"; @@ -285,24 +285,20 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); ICHECK(num != nullptr); - static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant"); - size_t unit = sizeof(TVMValue); + static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); + size_t unit = sizeof(TVMFFIAny); size_t size = 0; if (type == "shape") { - size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit; - } else if (type == "arg_value") { - size = (num->value * sizeof(TVMValue) + unit - 1) / unit; + size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit; } else if (type == "tvm_ffi_any") { size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit; - } else if (type == "arg_tcode") { - size = (num->value * sizeof(int) + unit - 1) / unit; } else if (type == "array") { size = (num->value * sizeof(DLTensor) + unit - 1) / unit; } else { LOG(FATAL) << "Unknown stack alloca type " << type; } this->PrintIndent(); - this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; + this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; os << stack_name; } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { this->PrintCallPacked(op); @@ -408,6 +404,6 @@ runtime::Module BuildCHost(IRModule mod, Target target) { return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } -TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); +TVM_FFI_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6bbb388a94cb..c3014b11a5be 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -24,7 +24,7 @@ #include "codegen_cuda.h" #include -#include +#include #include #include diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index d115916def54..0f87a16c449b 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -466,6 +466,6 @@ runtime::Module BuildMetal(IRModule mod, Target target) { return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } -TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); +TVM_FFI_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 9814696b3728..b94dc17bff33 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -672,7 +672,7 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str()); } -TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); +TVM_FFI_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { auto prototype_keys = target->GetKeys(); @@ -684,7 +684,7 @@ String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { return memory_scope; } -TVM_REGISTER_GLOBAL("DeviceScopeCompatibility.opencl") +TVM_FFI_REGISTER_GLOBAL("DeviceScopeCompatibility.opencl") .set_body_typed(DeviceScopeCompatibilityFromTarget); } // namespace codegen diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 8d1ad91746b6..995eddee027e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -779,7 +779,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { return runtime::Module(n); } -TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { +TVM_FFI_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { return BuildWebGPU(mod, target); }); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ec3cadb8c8e4..054edd861adc 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,10 +23,10 @@ */ #include +#include #include #include #include -#include #include #include @@ -175,7 +175,7 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_c") .set_body_typed(CSourceModuleNode::LoadFromBinary); /*! @@ -249,9 +249,9 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") +TVM_FFI_REGISTER_GLOBAL("runtime.CSourceModuleCreate") .set_body_typed([](String code, String fmt, Optional> func_names, Optional> const_vars) { return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 5690ef05de5c..f3dbd624ec00 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -35,7 +35,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target) { return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } -TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { +TVM_FFI_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ccb8d131c9d1..3010b74dd976 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -21,7 +21,7 @@ * \file intrin_rule_spirv.cc */ #include -#include +#include #include #include #include diff --git a/src/target/tag.cc b/src/target/tag.cc index 1099de85f2b4..f6e2307b75e1 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -22,8 +22,8 @@ * \brief Target tag registry */ +#include #include -#include #include #include @@ -33,8 +33,8 @@ namespace tvm { TVM_REGISTER_NODE_TYPE(TargetTagNode); -TVM_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); -TVM_REGISTER_GLOBAL("target.TargetTagAddTag").set_body_typed(TargetTag::AddTag); +TVM_FFI_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); +TVM_FFI_REGISTER_GLOBAL("target.TargetTagAddTag").set_body_typed(TargetTag::AddTag); /********** Registry-related code **********/ diff --git a/src/target/target.cc b/src/target/target.cc index 6885cf0cffe9..d9e3f9b51ee7 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,10 +21,10 @@ * \file src/target/target.cc */ #include +#include #include #include #include -#include #include #include #include @@ -40,8 +40,6 @@ #include #include -#include "../runtime/object_internal.h" - namespace tvm { TVM_REGISTER_NODE_TYPE(TargetNode); @@ -1010,16 +1008,16 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, /********** Registry **********/ -TVM_REGISTER_GLOBAL("target.Target").set_body_packed(TargetInternal::ConstructorDispatcher); -TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); -TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); -TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); -TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); -TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); -TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { +TVM_FFI_REGISTER_GLOBAL("target.Target").set_body_packed(TargetInternal::ConstructorDispatcher); +TVM_FFI_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); +TVM_FFI_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); +TVM_FFI_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); +TVM_FFI_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); +TVM_FFI_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); +TVM_FFI_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { return target->GetTargetDeviceType(); }); -TVM_REGISTER_GLOBAL("target.TargetGetFeature") +TVM_FFI_REGISTER_GLOBAL("target.TargetGetFeature") .set_body_typed([](const Target& target, const String& feature_key) -> Any { if (auto opt_any = target->GetFeature(feature_key)) { return opt_any.value(); diff --git a/src/target/target_info.cc b/src/target/target_info.cc index a63e45a81a4a..6e673905d3c2 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -20,8 +20,8 @@ /*! * \file target/target_info.cc */ +#include #include -#include #include namespace tvm { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 6f206b567c96..cdec2ede0643 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -21,9 +21,9 @@ * \file src/target/target_kind.cc * \brief Target kind registry */ +#include #include #include -#include #include #include @@ -446,7 +446,7 @@ TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break /********** Registry **********/ -TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") +TVM_FFI_REGISTER_GLOBAL("target.TargetKindGetAttr") .set_body_typed([](TargetKind kind, String attr_name) -> ffi::Any { auto target_attr_map = TargetKind::GetAttrMap(attr_name); ffi::Any rv; @@ -455,10 +455,11 @@ TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") } return rv; }); -TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); -TVM_REGISTER_GLOBAL("target.ListTargetKindOptions") +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKinds") + .set_body_typed(TargetKindRegEntry::ListTargetKinds); +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptions") .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); -TVM_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName") +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName") .set_body_typed([](String target_kind_name) { TargetKind kind = TargetKind::Get(target_kind_name).value(); return TargetKindRegEntry::ListTargetKindOptions(kind); diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index 3842776a6fd4..a39756662621 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -191,7 +191,7 @@ VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { virtual_device->target, virtual_device->memory_scope); } -TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") +TVM_FFI_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); } // namespace tvm diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 765113cddc35..294b34bf5d2e 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include #include @@ -148,7 +148,7 @@ ComputeOp::ComputeOp(std::string name, std::string tag, Map at data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp") +TVM_FFI_REGISTER_GLOBAL("te.ComputeOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, Array axis, Array body) { return ComputeOp(name, tag, attrs.value_or({}), axis, body); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 05fd5ae64a82..1534cfc35889 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -20,8 +20,8 @@ #include "create_primfunc.h" #include +#include #include -#include #include #include #include @@ -784,15 +784,16 @@ PrimFunc CreatePrimFunc(const Array& arg_list, return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - Array arg_list = args[0].cast>(); - std::optional index_dtype_override{std::nullopt}; - // Add conversion to make std::optional compatible with FFI. - if (args[1] != nullptr) { - index_dtype_override = args[1].cast(); - } - *ret = CreatePrimFunc(arg_list, index_dtype_override); -}); +TVM_FFI_REGISTER_GLOBAL("te.CreatePrimFunc") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + Array arg_list = args[0].cast>(); + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[1] != nullptr) { + index_dtype_override = args[1].cast(); + } + *ret = CreatePrimFunc(arg_list, index_dtype_override); + }); // Relax version impl PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 2bee0555570e..9f8531998e88 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -22,7 +22,7 @@ * \file extern_op.cc */ #include -#include +#include #include #include @@ -70,7 +70,7 @@ ExternOp::ExternOp(std::string name, std::string tag, Map attr data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ExternOp") +TVM_FFI_REGISTER_GLOBAL("te.ExternOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body) { diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index aee7f2afb188..e2bbced85f89 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -23,7 +23,7 @@ */ #include "graph.h" -#include +#include #include #include #include @@ -80,9 +80,9 @@ Array PostDFSOrder(const Array& roots, const ReadGraph& g) return post_order; } -TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); +TVM_FFI_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); -TVM_REGISTER_GLOBAL("schedule.PostDFSOrder") +TVM_FFI_REGISTER_GLOBAL("schedule.PostDFSOrder") .set_body_typed([](const Array& roots, const ReadGraph& g) { return PostDFSOrder(roots, g); }); diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 466def8b3014..cce70420c0bd 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -22,7 +22,7 @@ * \file placeholder_op.cc */ #include -#include +#include #include namespace tvm { @@ -61,7 +61,7 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { return PlaceholderOp(name, shape, dtype).output(0); } -TVM_REGISTER_GLOBAL("te.Placeholder") +TVM_FFI_REGISTER_GLOBAL("te.Placeholder") .set_body_typed([](Variant> shape_arg, DataType dtype, std::string name) { auto shape = [&]() -> Array { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 5e6ad5c78f38..f4860cf71ef7 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -21,7 +21,7 @@ * \brief Scan Operator. * \file scan_op.cc */ -#include +#include #include #include @@ -97,7 +97,7 @@ ScanOp::ScanOp(std::string name, std::string tag, Optional data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ScanOp") +TVM_FFI_REGISTER_GLOBAL("te.ScanOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, IterVar axis, Array init, Array update, Array state_placeholder, Array inputs) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index f46c095f3b08..a23f4b494ece 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -20,7 +20,7 @@ /*! * \file tensor.cc */ -#include +#include #include #include @@ -98,7 +98,7 @@ Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_in data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.Tensor") +TVM_FFI_REGISTER_GLOBAL("te.Tensor") .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { return Tensor(shape, dtype, op, value_index); }); @@ -112,19 +112,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Other tensor ops. -TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); +TVM_FFI_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); -TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { +TVM_FFI_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { return static_cast(std::hash()(tensor)); }); -TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { +TVM_FFI_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); +TVM_FFI_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); +TVM_FFI_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index d8fcee859f03..ce13ac56c81d 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -402,8 +402,9 @@ Array> GetBlockReadWriteRegion(const Block& block, return {reads, writes}; } -TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); -TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion") + .set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index e8cba6116d65..aca4c99e1197 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -346,6 +346,7 @@ Map> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } -TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca") + .set_body_typed(DetectBufferAccessLCA); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index e304abcdeb7d..de208ce9c1e0 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -96,7 +96,7 @@ tvm::Map > CalculateAllocatedBytes(const IRMod return results; } -TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") .set_body_typed([](ObjectRef obj) -> tvm::Map > { if (auto func = obj.as()) { return CalculateAllocatedBytes(func.value()); @@ -155,7 +155,7 @@ Array GetVTCMCompactionPasses() { return pass_list; } -TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { return GetVTCMCompactionPasses(); }); @@ -191,7 +191,7 @@ Pass VerifyVTCMLimit(Optional default_target) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index e26d6d12e61f..a9c2b9ecc609 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -24,7 +24,7 @@ #include "control_flow_graph.h" -#include +#include #include #include #include diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index fb4cab759069..07d6500570f8 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,10 +21,10 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ +#include #include #include #include -#include #include namespace tvm { @@ -68,7 +68,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt); } -TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); }); diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index c8bea6ecd04d..688fcd226300 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -245,17 +245,18 @@ double EstimateTIRFlops(const IRModule& mod) { return PostprocessResults(result) + cached_result; } -TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { - if (auto mod = obj.as()) { - return EstimateTIRFlops(mod.value()); - } else if (auto stmt = obj.as()) { - return EstimateTIRFlops(stmt.value()); - } else { - LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " - << obj->GetTypeKey(); - throw; - } -}); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops") + .set_body_typed([](ObjectRef obj) -> double { + if (auto mod = obj.as()) { + return EstimateTIRFlops(mod.value()); + } else if (auto stmt = obj.as()) { + return EstimateTIRFlops(stmt.value()); + } else { + LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " + << obj->GetTypeKey(); + throw; + } + }); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index 3ab96f634a31..dcffe1c1d6b8 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -282,7 +282,7 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an } // Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. -TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { +TVM_FFI_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { Array output; struct Visitor : arith::IRVisitorWithAnalyzer { diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index ee893987c91e..4af823604971 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -91,7 +91,7 @@ bool IsPureFunction(const PrimFunc& func, bool assert_on_error) { return PurityChecker::Check(func, assert_on_error); } -TVM_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index dbe114df4973..898a92adc7db 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -123,7 +123,7 @@ transform::Pass OOBChecker() { return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {}); } -TVM_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 7acbd2eb6ad2..b5a23e35d276 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -139,7 +139,7 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { return nullptr; } -TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { +TVM_FFI_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { auto ret = FindAnchorBlock(mod); if (ret) { return Optional(GetRef(ret)); diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 654d3332c755..0d75cebac798 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -199,7 +199,7 @@ Array UndefinedVars(const PrimExpr& expr, const Array& args) { return m.undefined_; } -TVM_REGISTER_GLOBAL("tir.analysis.UndefinedVars") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.UndefinedVars") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index d109736863c7..ef46a41687ad 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -24,7 +24,7 @@ * in a block exceeds the limit */ -#include +#include #include #include #include @@ -321,7 +321,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { return errs.size() == 0; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); namespace transform { @@ -346,7 +346,7 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index f8681189a1e6..bc567879c22b 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -21,8 +21,8 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ +#include #include -#include #include #include #include @@ -186,7 +186,7 @@ std::vector VerifyMemory_(const PrimFunc& func) { bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { @@ -211,7 +211,7 @@ Pass VerifyMemory() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index f238ffd763b1..33abb39c367f 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -23,7 +23,7 @@ * SSA requires each varaible to be only defined once. * \file verify_ssa.cc */ -#include +#include #include #include #include @@ -139,7 +139,7 @@ bool VerifySSA(const PrimFunc& func) { return visitor.is_ssa_; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); namespace transform { @@ -155,7 +155,7 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); } // namespace transform diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index cfdc2f35515a..a0c5f4829bf8 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -22,7 +22,7 @@ * \brief Check if schedulable tir is well-formed. */ -#include +#include #include #include @@ -368,7 +368,7 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { return true; } -TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") .set_body_typed([](const ObjectRef& obj, bool assert_mode) { if (auto opt = obj.as()) { return VerifyWellFormed(opt.value(), assert_mode); diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 9b1cb079fe28..dc1e3c48c924 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -85,11 +85,11 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { } TVM_REGISTER_NODE_TYPE(BlockDependenceInfoNode); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfo") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfo") .set_body_typed([](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfoGetBlockScope") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetBlockScope") .set_body_method(&BlockDependenceInfoNode::GetBlockScope); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfoGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetSRef") .set_body_typed([](BlockDependenceInfo self, Stmt stmt) -> Optional { auto it = self->stmt2ref.find(stmt.get()); return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index 5320c5d68a37..381fae73a475 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -190,18 +190,21 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode); TVM_REGISTER_NODE_TYPE(DependencyNode); TVM_REGISTER_NODE_TYPE(BlockScopeNode); -TVM_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional { +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }); -TVM_REGISTER_GLOBAL("tir.StmtSRefParent").set_body_typed([](StmtSRef sref) -> Optional { - return GetRef>(sref->parent); -}); -TVM_REGISTER_GLOBAL("tir.StmtSRefRootMark") // +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefParent") + .set_body_typed([](StmtSRef sref) -> Optional { + return GetRef>(sref->parent); + }); +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefRootMark") // .set_body_typed(StmtSRef::RootMark); -TVM_REGISTER_GLOBAL("tir.StmtSRefInlineMark") // +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefInlineMark") // .set_body_typed(StmtSRef::InlineMark); -TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc").set_body_method(&BlockScopeNode::GetDepsBySrc); -TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst").set_body_method(&BlockScopeNode::GetDepsByDst); +TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc") + .set_body_method(&BlockScopeNode::GetDepsBySrc); +TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst") + .set_body_method(&BlockScopeNode::GetDepsByDst); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 3b94c2ae757a..bce9c2c4e1a8 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -21,8 +21,8 @@ * \file buffer.cc */ #include +#include #include -#include #include #include #include @@ -640,7 +640,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std TVM_REGISTER_NODE_TYPE(BufferNode); -TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { +TVM_FFI_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; @@ -658,17 +658,18 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi:: axis_separators, span); }); -TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); +TVM_FFI_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); -TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer") + .set_body_method(&Buffer::GetFlattenedBuffer); -TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); +TVM_FFI_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); -TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); +TVM_FFI_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); -TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); +TVM_FFI_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); -TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); +TVM_FFI_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 119322455b1c..96f87344cbea 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -22,7 +22,7 @@ * \brief Data Layout expression. */ #include -#include +#include #include #include @@ -427,43 +427,45 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { +TVM_FFI_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { return Layout(name, dtype); }); -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::Get(axis)); -}); +TVM_FFI_REGISTER_GLOBAL("tir.LayoutIndexOf") + .set_body_typed([](Layout layout, std::string axis) -> int { + return layout.IndexOf(LayoutAxis::Get(axis)); + }); -TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") +TVM_FFI_REGISTER_GLOBAL("tir.LayoutFactorOf") .set_body_typed([](Layout layout, std::string axis) -> int { return layout.FactorOf(LayoutAxis::Get(axis)); }); -TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { +TVM_FFI_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { - const LayoutAxis& axis = layout[idx]; - return axis.name(); -}); +TVM_FFI_REGISTER_GLOBAL("tir.LayoutGetItem") + .set_body_typed([](Layout layout, int idx) -> std::string { + const LayoutAxis& axis = layout[idx]; + return axis.name(); + }); -TVM_REGISTER_GLOBAL("tir.BijectiveLayout") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayout") .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { return BijectiveLayout(src_layout, dst_layout); }); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") .set_body_method(&BijectiveLayout::ForwardIndex); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") .set_body_method(&BijectiveLayout::BackwardIndex); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") .set_body_method(&BijectiveLayout::ForwardShape); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 4b9dca8989a9..0ac59b160200 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -20,7 +20,7 @@ /*! * \file expr.cc */ -#include +#include #include #include #include @@ -43,7 +43,7 @@ namespace tir { * `expr.dtype` field), this function allows the FFI conversions to be * explicitly invoked. */ -TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { +TVM_FFI_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { return expr; }); @@ -127,7 +127,8 @@ Var Var::copy_with_dtype(DataType dtype) const { return Var(new_ptr); } -TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, + Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { @@ -156,7 +157,7 @@ SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { return SizeVar(s, t, span); }); @@ -182,7 +183,7 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.IterVar") +TVM_FFI_REGISTER_GLOBAL("tir.IterVar") .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) { return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); @@ -198,7 +199,7 @@ StringImm::StringImm(String value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { return StringImm(value, span); }); @@ -216,7 +217,7 @@ Cast::Cast(DataType t, PrimExpr value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { return Cast(dtype, value, span); }); @@ -225,7 +226,7 @@ TVM_REGISTER_NODE_TYPE(CastNode); // Add TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Add(a, b, span); }); @@ -234,7 +235,7 @@ TVM_REGISTER_NODE_TYPE(AddNode); // Sub TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Sub(a, b, span); }); @@ -243,7 +244,7 @@ TVM_REGISTER_NODE_TYPE(SubNode); // Mul TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); -TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Mul(a, b, span); }); @@ -252,7 +253,7 @@ TVM_REGISTER_NODE_TYPE(MulNode); // Div TVM_DEFINE_BINOP_CONSTRUCTOR(Div); -TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Div(a, b, span); }); @@ -261,7 +262,7 @@ TVM_REGISTER_NODE_TYPE(DivNode); // Mod TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); -TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Mod(a, b, span); }); @@ -270,7 +271,7 @@ TVM_REGISTER_NODE_TYPE(ModNode); // FloorDiv TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); -TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return FloorDiv(a, b, span); }); @@ -279,7 +280,7 @@ TVM_REGISTER_NODE_TYPE(FloorDivNode); // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return FloorMod(a, b, span); }); @@ -288,7 +289,7 @@ TVM_REGISTER_NODE_TYPE(FloorModNode); // Min TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Min(a, b, span); }); @@ -297,7 +298,7 @@ TVM_REGISTER_NODE_TYPE(MinNode); // Max TVM_DEFINE_BINOP_CONSTRUCTOR(Max); -TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Max(a, b, span); }); @@ -306,7 +307,7 @@ TVM_REGISTER_NODE_TYPE(MaxNode); // EQ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); -TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); @@ -315,7 +316,7 @@ TVM_REGISTER_NODE_TYPE(EQNode); // NE TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); -TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); @@ -324,7 +325,7 @@ TVM_REGISTER_NODE_TYPE(NENode); // LT TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); -TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); @@ -333,7 +334,7 @@ TVM_REGISTER_NODE_TYPE(LTNode); // LE TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); -TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); @@ -342,7 +343,7 @@ TVM_REGISTER_NODE_TYPE(LENode); // GT TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); -TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); @@ -351,7 +352,7 @@ TVM_REGISTER_NODE_TYPE(GTNode); // GE TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); -TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); @@ -374,7 +375,7 @@ And::And(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return And(a, b, span); }); @@ -397,7 +398,7 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); @@ -416,7 +417,9 @@ Not::Not(PrimExpr a, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { return Not(a, span); }); +TVM_FFI_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { + return Not(a, span); +}); TVM_REGISTER_NODE_TYPE(NotNode); @@ -442,7 +445,7 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Select") +TVM_FFI_REGISTER_GLOBAL("tir.Select") .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { return Select(condition, true_value, false_value, span); }); @@ -481,7 +484,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Ramp") +TVM_FFI_REGISTER_GLOBAL("tir.Ramp") .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); @@ -514,9 +517,10 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { data_ = node; } -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { - return Broadcast(value, lanes, span); -}); +TVM_FFI_REGISTER_GLOBAL("tir.Broadcast") + .set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { + return Broadcast(value, lanes, span); + }); TVM_REGISTER_NODE_TYPE(BroadcastNode); @@ -535,8 +539,8 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, - Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, + Span span) { return Let(var, value, body, span); }); @@ -556,7 +560,7 @@ Call::Call(DataType dtype, RelaxExpr op, Array args, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Call") +TVM_FFI_REGISTER_GLOBAL("tir.Call") .set_body_typed([](Optional dtype, RelaxExpr op, Array> args, Span span) { @@ -631,7 +635,7 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { return Shuffle({vector}, {Integer(index)}, span); } -TVM_REGISTER_GLOBAL("tir.Shuffle") +TVM_FFI_REGISTER_GLOBAL("tir.Shuffle") .set_body_typed([](Array vectors, Array indices, Span span) { return Shuffle(vectors, indices, span); }); @@ -691,13 +695,14 @@ Array CommReducerNode::operator()(Array a, Array b return Substitute(this->result, value_map); } -TVM_REGISTER_GLOBAL("tir.CommReducer") +TVM_FFI_REGISTER_GLOBAL("tir.CommReducer") .set_body_typed([](Array lhs, Array rhs, Array result, Array identity_element, Span span) { return CommReducer(lhs, rhs, result, identity_element, span); }); -TVM_REGISTER_GLOBAL("tir.CommReducerCombine").set_body_method(&tir::CommReducerNode::operator()); +TVM_FFI_REGISTER_GLOBAL("tir.CommReducerCombine") + .set_body_method(&tir::CommReducerNode::operator()); TVM_REGISTER_NODE_TYPE(CommReducerNode); @@ -736,7 +741,7 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.Reduce") +TVM_FFI_REGISTER_GLOBAL("tir.Reduce") .set_body_typed([](CommReducer combiner, Array source, Array axis, PrimExpr condition, int value_index, Array init, Span span) { return Reduce(combiner, source, axis, condition, value_index, init, span); @@ -811,7 +816,7 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional indices, Optional predicate, Span span) { return BufferLoad(buffer, indices, predicate, span); }); @@ -827,7 +832,7 @@ ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.ProducerLoad") +TVM_FFI_REGISTER_GLOBAL("tir.ProducerLoad") .set_body_typed([](DataProducer producer, Array indices, Span span) { return ProducerLoad(producer, indices, span); }); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 5f85b74c8d27..2312d31fd276 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -21,8 +21,8 @@ * \file src/tir/ir/function.cc * \brief The function data structure. */ +#include #include -#include #include #include #include @@ -155,19 +155,19 @@ Optional TensorIntrin::Get(String name, bool allow_missing) { TVM_REGISTER_NODE_TYPE(TensorIntrinNode); -TVM_REGISTER_GLOBAL("tir.PrimFunc") +TVM_FFI_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, Map buffer_map, DictAttrs attrs, Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); -TVM_REGISTER_GLOBAL("tir.TensorIntrin") +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrin") .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { return TensorIntrin(desc_func, intrin_func); }); -TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); -TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index ceac7e4079ba..7297b62bf36d 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -419,33 +419,34 @@ IndexMap Substitute(const IndexMap& index_map, TVM_REGISTER_NODE_TYPE(IndexMapNode); -TVM_REGISTER_GLOBAL("tir.IndexMap") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMap") .set_body_typed([](Array initial_indices, Array final_indices, Optional inverse_index_map) { return IndexMap(initial_indices, final_indices, inverse_index_map); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapIndices") .set_body_typed([](IndexMap map, Array indices) { arith::Analyzer analyzer; return map->MapIndices(indices, &analyzer); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Array shape) { - arith::Analyzer analyzer; - return map->MapShape(shape, &analyzer); -}); +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapShape") + .set_body_typed([](IndexMap map, Array shape) { + arith::Analyzer analyzer; + return map->MapShape(shape, &analyzer); + }); -TVM_REGISTER_GLOBAL("tir.IndexMapInverse") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapInverse") .set_body_typed([](IndexMap map, Array initial_ranges) { arith::Analyzer analyzer; return map.Inverse(initial_ranges, &analyzer); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapNDArray") .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); -TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") .set_body_typed([](IndexMap forward, Array initial_ranges) { arith::Analyzer analyzer; auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c28f7e2b490d..7e8c2913e55f 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -160,7 +160,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } } -TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); +TVM_FFI_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h index 8df04566460a..273ca946a7ff 100644 --- a/src/tir/ir/script/script_complete.h +++ b/src/tir/ir/script/script_complete.h @@ -23,7 +23,7 @@ */ #ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ #define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ -#include +#include #include #include diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index c7e254faacfe..86ed65c4905d 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -21,7 +21,7 @@ * \file src/tir/ir/specialize.cc * \brief Specialize parameters of PrimFunc. */ -#include +#include #include #include #include @@ -432,7 +432,7 @@ PrimFunc Specialize(PrimFunc func, const Map>& pa /**************** FFI ****************/ -TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); +TVM_FFI_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6376d1edfd0b..62baf45bc78e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -21,7 +21,7 @@ * \file tvm/tir/stmt.cc */ #include -#include +#include #include #include #include @@ -52,7 +52,7 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt") +TVM_FFI_REGISTER_GLOBAL("tir.LetStmt") .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { return LetStmt(var, value, body, span); }); @@ -70,7 +70,7 @@ AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, S data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.AttrStmt") +TVM_FFI_REGISTER_GLOBAL("tir.AttrStmt") .set_body_typed([](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to primexpr. if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { @@ -100,7 +100,7 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_GLOBAL("tir.AssertStmt") +TVM_FFI_REGISTER_GLOBAL("tir.AssertStmt") .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span span) { return AssertStmt(condition, message, body, span); }); @@ -155,7 +155,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.For").set_body_typed( +TVM_FFI_REGISTER_GLOBAL("tir.For").set_body_typed( [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional thread_binding, Optional> annotations, Span span) { return For(loop_var, min, extent, static_cast(kind), body, thread_binding, @@ -199,7 +199,7 @@ While::While(PrimExpr condition, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { return While(condition, body, span); }); @@ -216,7 +216,7 @@ ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, Span span) { return ProducerStore(producer, value, indices, span); }); @@ -267,7 +267,7 @@ int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { return static_cast(result); } -TVM_REGISTER_GLOBAL("tir.Allocate") +TVM_FFI_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { return Allocate(buffer_var, type, extents, condition, body, annotations, span); @@ -328,7 +328,7 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents } return static_cast(result); } -TVM_REGISTER_GLOBAL("tir.AllocateConst") +TVM_FFI_REGISTER_GLOBAL("tir.AllocateConst") .set_body_typed([](Var buffer_var, DataType dtype, Array extents, ObjectRef data_or_idx, Stmt body, Optional> annotations, Span span) { @@ -347,7 +347,7 @@ DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { return DeclBuffer(buffer, body, span); }); @@ -376,7 +376,7 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.ProducerRealize") +TVM_FFI_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { return ProducerRealize(producer, bounds, condition, body, storage_scope, span); @@ -389,7 +389,7 @@ Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { data_ = make_object(buffer, bounds, span); } -TVM_REGISTER_GLOBAL("tir.Prefetch") +TVM_FFI_REGISTER_GLOBAL("tir.Prefetch") .set_body_typed([](Buffer buffer, Array bounds, Span span) { return Prefetch(buffer, bounds, span); }); @@ -423,7 +423,7 @@ SeqStmt::SeqStmt(Array seq, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); @@ -444,7 +444,7 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_c TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_REGISTER_GLOBAL("tir.IfThenElse") +TVM_FFI_REGISTER_GLOBAL("tir.IfThenElse") .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { return IfThenElse(condition, then_case, else_case, span); }); @@ -459,7 +459,7 @@ Evaluate::Evaluate(PrimExpr value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { return Evaluate(value, span); }); @@ -541,7 +541,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferStore") +TVM_FFI_REGISTER_GLOBAL("tir.BufferStore") .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Optional predicate, Span span) { return BufferStore(buffer, value, indices, predicate, span); }); @@ -554,7 +554,7 @@ BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condit data_ = make_object(buffer, bounds, condition, body, span); } -TVM_REGISTER_GLOBAL("tir.BufferRealize") +TVM_FFI_REGISTER_GLOBAL("tir.BufferRealize") .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); @@ -608,7 +608,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { return BufferRegion(buffer, region); } -TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { +TVM_FFI_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { return BufferRegion(buffer, region); }); @@ -665,9 +665,10 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) { - return MatchBufferRegion(buffer, source); -}); +TVM_FFI_REGISTER_GLOBAL("tir.MatchBufferRegion") + .set_body_typed([](Buffer buffer, BufferRegion source) { + return MatchBufferRegion(buffer, source); + }); TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); @@ -689,7 +690,7 @@ Block::Block(Array iter_vars, Array reads, Array iter_vars, Array reads, Array writes, String name_hint, Stmt body, Optional init, Array alloc_buffers, Array match_buffers, @@ -713,7 +714,7 @@ BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block blo data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BlockRealize") +TVM_FFI_REGISTER_GLOBAL("tir.BlockRealize") .set_body_typed([](Array iter_values, PrimExpr predicate, Block block, Span span) { return BlockRealize(iter_values, predicate, block, span); }); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index da7896e40bf0..85d347172702 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -19,8 +19,8 @@ /*! * \file stmt_functor.cc */ +#include #include -#include #include #include #include @@ -892,17 +892,17 @@ PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); } -TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); +TVM_FFI_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); -TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { +TVM_FFI_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); -TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { +TVM_FFI_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }); -TVM_REGISTER_GLOBAL("tir.Substitute") +TVM_FFI_REGISTER_GLOBAL("tir.Substitute") .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index f724b6a74598..6a5e1191d219 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -21,9 +21,9 @@ * \file tir/ir/transform.cc * \brief TIR specific transformation passes. */ +#include #include #include -#include #include namespace tvm { @@ -144,7 +144,7 @@ Pass CreatePrimFuncPass(std::function TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); -TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") +TVM_FFI_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index a336688622a4..70614dfeebd7 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -22,7 +22,7 @@ * * builtin intrinsic operators. */ -#include +#include #include #include #include diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 9752c052a161..341a96cae697 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -23,7 +23,7 @@ * Common operator definitions for ops in tir/op.h */ -#include +#include #include #include #include @@ -239,7 +239,7 @@ PrimExpr ret(PrimExpr value, Span span) { return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } -TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); +TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { @@ -761,7 +761,7 @@ PrimExpr bitwise_neg(PrimExpr a, Span span) { return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); } -TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) { return bitwise_neg(a, span); }); @@ -1071,7 +1071,7 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace -TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { +TVM_FFI_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (auto opt = args[0].try_cast()) { *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); } else if (auto opt = args[0].try_cast()) { @@ -1082,55 +1082,55 @@ TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi: } }); -TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); +TVM_FFI_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); -TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); +TVM_FFI_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); -TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); +TVM_FFI_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); -TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity); +TVM_FFI_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity); -TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); +TVM_FFI_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); -TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely); +TVM_FFI_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely); -TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); +TVM_FFI_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); -TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); +TVM_FFI_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); -TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); +TVM_FFI_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); -TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); +TVM_FFI_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); -TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); +TVM_FFI_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); -TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); +TVM_FFI_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); -TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); +TVM_FFI_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); -TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); +TVM_FFI_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); -TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_FFI_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); -TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); +TVM_FFI_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ - return (Func(a, b, span)); \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ + TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ + return (Func(a, b, span)); \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ - bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } \ +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ + bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } \ }) REGISTER_MAKE_BINARY_OP(_OpAdd, add); @@ -1163,12 +1163,12 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor); REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, right_shift); -TVM_REGISTER_GLOBAL("tir._OpIfThenElse") +TVM_FFI_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { return if_then_else(cond, true_value, false_value, span); }); -TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { return const_true(t.lanes(), span); }); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7de3890c4ad8..99f4050a84e5 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -328,7 +328,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; } -TVM_REGISTER_GLOBAL("tir.schedule.IsReductionBlock") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsReductionBlock") .set_body_typed([](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); }); @@ -864,7 +864,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } -TVM_REGISTER_GLOBAL("tir.schedule.GetBlockRealize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetBlockRealize") .set_body_typed([](Schedule sch, BlockRV block_rv) { return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); }); @@ -1483,7 +1483,7 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { return true; } -TVM_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding") .set_body_typed([](Schedule sch, BlockRV block_rv) { return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); }); @@ -1891,8 +1891,8 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, return TensorizeInfo(ret); } -TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc); -TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) { return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); }); @@ -2121,19 +2121,20 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode); -TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }); -TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); -TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch, BlockRV block) { - auto state = sch->state(); - auto block_sref = sch->GetSRef(block); - return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); -}); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsOutputBlock") + .set_body_typed([](Schedule sch, BlockRV block) { + auto state = sch->state(); + auto block_sref = sch->GetSRef(block); + return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); + }); -TVM_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") .set_body_typed([](Schedule sch, LoopRV loop) -> String { IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); if (kind == kDataPar) { diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 1daea910377f..13b35582eefc 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -238,7 +238,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); } -TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") .set_body_typed([](Buffer buffer, Array indices, Array loops, PrimExpr predicate) { arith::Analyzer analyzer; diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 2a4efe3e8ab4..7fd43c9242f0 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -100,8 +100,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(InstructionNode); TVM_REGISTER_NODE_TYPE(InstructionKindNode); -TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); -TVM_REGISTER_GLOBAL("tir.schedule.Instruction") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.Instruction") .set_body_typed([](InstructionKind kind, Array inputs, Array attrs, Array outputs) -> Instruction { return Instruction(kind, inputs, attrs, outputs); diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 299bc9a62d5a..94db2070c709 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -531,7 +531,7 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, /******** FFI ********/ -TVM_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding") .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) { return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); }); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index e457bda5a86a..326d373d6e70 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -1344,7 +1344,7 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ -TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.RegisterReducer") .set_body_typed([](int n_buffers, ffi::Function combiner_getter, ffi::Function identity_getter) { ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 29e40d4003d6..8dc1dcf8dbb2 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -43,35 +43,35 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // .set_body_method(&ScheduleNode::mod); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // .set_body_method(&ScheduleNode::func_working_on); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // .set_body_method(&ScheduleNode::Seed); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // .set_body_method(&ScheduleNode::ForkSeed); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // .set_body_method(&ScheduleNode::WorkOn); /**************** (FFI) Constructor ****************/ -TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); -TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); -TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Concrete(mod, debug_mask, seed, static_cast(error_render_level), enable_check); }); -TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracedSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Traced(mod, seed, debug_mask, @@ -81,7 +81,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") /******** (FFI) Lookup random variables ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGet") .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { if (auto loop_rv = obj.as()) { return self->Get(loop_rv.value()); @@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") << ". Its value is: " << obj; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") .set_body_typed([](Schedule self, ObjectRef obj) -> Optional { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); @@ -110,7 +110,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") .set_body_typed([](Schedule self, ObjectRef obj) -> void { if (auto loop_rv = obj.as()) { return self->RemoveRV(loop_rv.value()); @@ -126,18 +126,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePartitionedTile") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePartitionedTile") .set_body_method(&ScheduleNode::SamplePartitionedTile); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock").set_body_method(&ScheduleNode::GetBlock); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops").set_body_method(&ScheduleNode::GetLoops); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock").set_body_method(&ScheduleNode::GetBlock); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops").set_body_method(&ScheduleNode::GetLoops); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") .set_body_typed([](Schedule self, ObjectRef rv) { if (auto block_rv = rv.as()) { return self->GetChildBlocks(block_rv.value()); @@ -149,22 +149,22 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") << ". Its value is: " << rv; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") .set_body_method(&ScheduleNode::GetProducers); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") .set_body_method(&ScheduleNode::GetConsumers); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") .set_body_method(&ScheduleNode::GetOutputBlocks); /******** (FFI) Transform loops ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") .set_body_method(&ScheduleNode::LoopPartition); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder").set_body_method(&ScheduleNode::Reorder); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorder").set_body_method(&ScheduleNode::Reorder); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") .set_body_method(&ScheduleNode::ReorderBlockIterVar); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { if (auto loop_rv = rv.as()) { return self->AddUnitLoop(loop_rv.value()); @@ -177,48 +177,50 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") } }); /******** (FFI) Manipulate ForKind ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel").set_body_method(&ScheduleNode::Parallel); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize").set_body_method(&ScheduleNode::Vectorize); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleParallel").set_body_method(&ScheduleNode::Parallel); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize").set_body_method(&ScheduleNode::Vectorize); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead").set_body_method(&ScheduleNode::CacheRead); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite").set_body_method(&ScheduleNode::CacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead").set_body_method(&ScheduleNode::CacheRead); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") + .set_body_method(&ScheduleNode::CacheWrite); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") .set_body_method(&ScheduleNode::ReindexCacheRead); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") .set_body_method(&ScheduleNode::ReindexCacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex").set_body_method(&ScheduleNode::CacheIndex); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") + .set_body_method(&ScheduleNode::CacheIndex); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); }); /******** (FFI) Data movement ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt").set_body_method(&ScheduleNode::WriteAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt").set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt").set_body_method(&ScheduleNode::ComputeAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt").set_body_method(&ScheduleNode::ComputeAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") .set_body_method(&ScheduleNode::ReverseComputeAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") .set_body_method(&ScheduleNode::ReverseComputeInline); /******** (FFI) Reduction ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction") .set_body_method(&ScheduleNode::DecomposeReduction); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor").set_body_method(&ScheduleNode::RFactor); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor").set_body_method(&ScheduleNode::RFactor); /******** (FFI) Block annotation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope").set_body_method(&ScheduleNode::SetScope); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope").set_body_method(&ScheduleNode::SetScope); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") .set_body_method(&ScheduleNode::UnsafeSetDType); /******** (FFI) Blockize & Tensorize ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_typed([](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); @@ -227,7 +229,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); @@ -240,7 +242,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") }); /******** (FFI) Annotation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); @@ -252,7 +254,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") << ". Its value is: " << rv; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); @@ -266,7 +268,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") }); /******** (FFI) Layout transformation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map, const Optional& pad_value, bool assume_injective_transform) { @@ -274,9 +276,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") static_cast(buffer_index_type), index_map, pad_value, assume_injective_transform); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") .set_body_method(&ScheduleNode::TransformBlockLayout); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const Array& axis_separators) { return self->SetAxisSeparator( @@ -284,19 +286,19 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") }); /******** (FFI) Padding decomposition ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") .set_body_method(&ScheduleNode::DecomposePadding); -TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum").set_body_method(&ScheduleNode::PadEinsum); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum").set_body_method(&ScheduleNode::PadEinsum); /******** (FFI) Buffer transformation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") .set_body_method(&ScheduleNode::RollingBuffer); /******** (FFI) Misc ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") .set_body_method(&ScheduleNode::UnsafeHideBufferAccess); /******** (FFI) Annotate buffer access ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map) { return self->AnnotateBufferAccess(block_rv, buffer_index, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 8057492dbb04..f2c4b56121c9 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -1012,20 +1012,20 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleState") .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { return ScheduleState(mod, debug_mask, enable_check); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") .set_body_method(&ScheduleStateNode::Replace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional { auto it = self->stmt2ref.find(stmt.get()); return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 883bd65ce348..1992f5ae8a69 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -563,13 +563,13 @@ TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_REGISTER_GLOBAL("tir.schedule.Trace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.Trace") .set_body_typed([](Optional> insts, Optional> decisions) { return Trace(insts.value_or(Array()), decisions.value_or({})); }); -TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision").set_body_method(&TraceNode::GetDecision); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceGetDecision").set_body_method(&TraceNode::GetDecision); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAppend") .set_body_typed([](Trace self, Instruction inst, Optional decision) { if (decision.defined()) { return self->Append(inst, decision.value()); @@ -577,14 +577,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") return self->Append(inst); } }); -TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); -TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") .set_body_method(&TraceNode::ApplyToSchedule); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); -TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision").set_body_method(&TraceNode::WithDecision); -TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); -TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceWithDecision").set_body_method(&TraceNode::WithDecision); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") .set_body_typed(Trace::ApplyJSONToSchedule); } // namespace tir diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index fa6a71fce64c..c0929e01a8ad 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -439,7 +439,7 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block return reorder_suffix[0]; } -TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); /******** BlockBufferAccessSimplifier ********/ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { @@ -557,7 +557,7 @@ Optional NormalizePrimFunc(Schedule sch) { return Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } -TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805b..f8adcf4f5010 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -21,8 +21,8 @@ * \file annotate_device_regions.cc * \brief Split device function from host. */ +#include #include -#include #include #include #include @@ -74,7 +74,8 @@ Pass AnnotateDeviceRegions() { return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); } -TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions); +TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions") + .set_body_typed(AnnotateDeviceRegions); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 66d0fb61661b..06d596adb44d 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -23,8 +23,8 @@ * Re-write data access to enable memory sharing when possible. */ #include +#include #include -#include #include #include #include diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 616b47f29403..15728e846224 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -23,7 +23,7 @@ // Instrument checkers for out of the bounds access. #include -#include +#include #include #include #include @@ -255,7 +255,7 @@ Pass InstrumentBoundCheckers() { return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") +TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") .set_body_typed(InstrumentBoundCheckers); } // namespace transform diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 18e568c83e74..2ff7c03c6287 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -22,9 +22,9 @@ * * \file combine_context_call.cc */ +#include #include #include -#include #include #include #include @@ -112,7 +112,7 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); +TVM_FFI_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index fb24ebf099eb..42409efb0bd1 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -655,7 +655,7 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { } // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it -TVM_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR); +TVM_FFI_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 543b687df0e8..c5c6accf221a 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -756,7 +756,7 @@ Pass CompactBufferAllocation(bool is_strict) { return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") +TVM_FFI_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") .set_body_typed(CompactBufferAllocation); } // namespace transform diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index ab8d98a00e0e..1b29cea2f27a 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -122,7 +122,8 @@ Pass ConvertBlocksToOpaque() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque").set_body_typed(ConvertBlocksToOpaque); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque") + .set_body_typed(ConvertBlocksToOpaque); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc index d01ae8a45113..4c992163df04 100644 --- a/src/tir/transforms/convert_for_loops_serial.cc +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -66,7 +66,7 @@ Pass ConvertForLoopsToSerial() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") .set_body_typed(ConvertForLoopsToSerial); } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 5034a858130d..3b382850559a 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -20,7 +20,7 @@ /*! * \file decorate_device_scope.cc */ -#include +#include #include #include #include @@ -44,7 +44,7 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); +TVM_FFI_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 2bf28abd2fcc..398b00092d08 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -162,7 +162,7 @@ Pass DefaultGPUSchedule() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); +TVM_FFI_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); } // namespace transform diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 052f7cf948cb..509efb8d06fd 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -25,8 +25,8 @@ * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0022-tir-non-scalar-constants.md */ #include +#include #include -#include #include #include "ir_utils.h" @@ -105,7 +105,7 @@ tvm::transform::Pass ExtractPrimFuncConstants() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") .set_body_typed(ExtractPrimFuncConstants); } // namespace transform diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index a6da7f7fc407..5ea0a60ea2a8 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -279,7 +279,7 @@ Pass FlattenBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index 86f839c4f5e2..bd33e564e5c2 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -86,7 +86,7 @@ Pass ForceNarrowIndexToInt32() { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") .set_body_typed(ForceNarrowIndexToInt32); } // namespace transform diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index f0fc90ee3244..d1c2155fd066 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -21,7 +21,7 @@ * \file hoist_expression.cc */ #include -#include +#include #include #include #include @@ -552,7 +552,7 @@ Pass HoistExpression() { "tir.HoistExpression"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -587,7 +587,7 @@ Pass HoistIfThenElse() { "tir.HoistIfThenElse"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); Pass HoistIfThenElseBasic() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -607,7 +607,7 @@ Pass HoistIfThenElseBasic() { "tir.HoistIfThenElseBasic"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 52e4d44b615a..6b992ce1f999 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -21,7 +21,7 @@ * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ -#include +#include #include #include #include @@ -319,7 +319,7 @@ Pass InjectDoubleBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index 8a1f4b1ff5a5..00e29061ba3a 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -295,7 +295,7 @@ Pass InjectPermutedLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPermutedLayout").set_body_typed(InjectPermutedLayout); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPermutedLayout").set_body_typed(InjectPermutedLayout); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 5d23e854be02..04bcecac36b0 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -199,7 +199,7 @@ Pass InjectPTXAsyncCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index b4c398bd17eb..c3a6cf50b828 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include #include @@ -123,7 +123,7 @@ Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) { // The pass can now be invoked via the pass infrastructure, but we also add a // Python binding for it -TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 03f94e3e9139..ed35bdb0655f 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -34,7 +34,7 @@ https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836 */ #include -#include +#include #include #include @@ -315,7 +315,7 @@ Pass InjectRollingBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index d54223c85fd4..4f137619ea7e 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -1259,7 +1259,8 @@ Pass InjectSoftwarePipeline() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline") + .set_body_typed(InjectSoftwarePipeline); } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index d8df0b0e0509..334d9594616d 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -20,7 +20,7 @@ /*! * \file inject_virtual_thread.cc */ -#include +#include #include #include #include @@ -525,7 +525,7 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); } // namespace transform diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index d14bb05406e7..eae2e29ef686 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -21,7 +21,7 @@ * \file inline_private_functions.cc * \brief Inline private functions to their callsite */ -#include +#include #include #include #include @@ -292,7 +292,8 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions") + .set_body_typed(InlinePrivateFunctions); } // namespace transform diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index f62a12bac2ed..0017e97beb88 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -850,7 +850,7 @@ Pass ConvertSSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 7a9a1f59977b..b30a47c84fe9 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -183,7 +183,7 @@ Pass LiftThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 3a0b80921ff9..1adc0f6b043a 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include @@ -810,7 +810,7 @@ Pass LoopPartition() { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); } // namespace transform diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index e1ec0f1572c7..c3358e1c9207 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -175,7 +175,7 @@ Pass LowerAsyncDMA() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 37a31d204427..31d7f91d74e9 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -934,7 +934,7 @@ Pass LowerCrossThreadReduction() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") .set_body_typed(LowerCrossThreadReduction); } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 19f529103b5c..dbc529cfeabd 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -21,7 +21,7 @@ * \brief Pass for lowering custom datatypes */ -#include +#include #include #include #include @@ -249,7 +249,7 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); } // namespace transform diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index bb63aeb0a337..2ca0e6d92f68 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -21,8 +21,8 @@ * \file lower_device_kernel_launch.cc * \brief Split device function from host. */ +#include #include -#include #include #include #include @@ -369,7 +369,7 @@ Pass LowerDeviceKernelLaunch() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") .set_body_typed(LowerDeviceKernelLaunch); } // namespace transform diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 8951f70a6c5b..a30232b9ce80 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -22,7 +22,7 @@ * \brief Lower the special device storage access. */ #include -#include +#include #include #include #include @@ -130,7 +130,7 @@ Pass LowerDeviceStorageAccessInfo() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") .set_body_typed(LowerDeviceStorageAccessInfo); } // namespace transform diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index 8b450784a020..03188fb6c907 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -79,7 +79,7 @@ Pass LowerInitBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); } // namespace transform diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 8e141128f67d..8fe9bedce9f0 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -21,7 +21,7 @@ * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ -#include +#include #include #include #include @@ -394,7 +394,7 @@ Pass LowerIntrin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); } // namespace transform diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 3c2c6b67e653..6e2ea5bc14af 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -267,7 +267,7 @@ Pass LowerMatchBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); } // namespace transform diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 9939ac9dec3e..f3551987426d 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -213,7 +213,7 @@ Pass LowerOpaqueBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 34b99fea0782..0d2092338228 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -22,7 +22,7 @@ * \file lower_thread_allreduce.cc */ #include -#include +#include #include #include #include @@ -809,7 +809,7 @@ Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index a0478325e18a..095bd321c937 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -21,7 +21,7 @@ * Lower TVM related builtin intrinsics such as packed call. * \file tir/transforms/lower_tvm_buildin.cc */ -#include +#include #include #include #include @@ -673,7 +673,7 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc index 0b5f7bf1554d..eac2a21b4917 100644 --- a/src/tir/transforms/lower_vtcm_alloc.cc +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -72,7 +72,7 @@ Pass LowerVtcmAlloc() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); } // namespace transform diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 4a364c0ecb8b..b1642bef3c92 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -27,7 +27,7 @@ // explaining the concept of warp shuffle. #include #include -#include +#include #include #include #include @@ -461,7 +461,7 @@ Pass LowerWarpMemory() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 8ab5ff7a3fa8..83ce75cead19 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,8 +20,8 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ +#include #include -#include #include #include #include @@ -438,7 +438,9 @@ Pass MakePackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { return MakePackedAPI(); }); +TVM_FFI_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { + return MakePackedAPI(); +}); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 26a40187fb86..a72d68972735 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -20,8 +20,8 @@ /*! * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. */ +#include #include -#include #include #include #include @@ -200,7 +200,7 @@ Pass MakeUnpackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); +TVM_FFI_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 885d5917136d..dc9420f728be 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -275,7 +275,7 @@ Pass ManifestSharedMemoryLocalStage() { return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") .set_body_typed(ManifestSharedMemoryLocalStage); } // namespace transform diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 6d35cc5ac2d1..916c5c84e9af 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include #include @@ -776,7 +776,7 @@ Pass LowerAutoCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h index e8dc22be4f50..46c9a97c527d 100644 --- a/src/tir/transforms/memhammer_rewrite_rule.h +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -20,7 +20,7 @@ #define TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ #include -#include +#include #include #include #include diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 85f102cb4177..52966e005aaa 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -23,7 +23,7 @@ * This pass merges multiple TIR-level dynamic or static shared memory allocations into one * allocation. */ -#include +#include #include #include #include @@ -695,7 +695,7 @@ Pass MergeSharedMemoryAllocations() { return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations") +TVM_FFI_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations") .set_body_typed(MergeSharedMemoryAllocations); } // namespace transform diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 696eae201f3c..8183b2fd8f45 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,7 +22,7 @@ * \brief narrow the datatype of indexing vars */ -#include +#include #include #include #include @@ -320,7 +320,7 @@ Pass NarrowDataType(int target_bits) { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); +TVM_FFI_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 0845f4e3c5f5..c141ef33c289 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -257,7 +257,7 @@ Pass PlanAndUpdateBufferAllocationLocation() { return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); } -TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") +TVM_FFI_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") .set_body_typed(PlanAndUpdateBufferAllocationLocation); } // namespace transform diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 7efa23bc322d..ade1aea7c941 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -109,9 +109,9 @@ transform::Pass Filter(ffi::TypedFunction fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); -TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); -TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); +TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); +TVM_FFI_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index 7f6930e2e2bf..f8548ca59a7d 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -283,7 +283,7 @@ Pass InstrumentProfileIntrinsics() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics") +TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics") .set_body_typed(InstrumentProfileIntrinsics); } // namespace transform diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index 0c3f7a9ba32f..0593c4f812fe 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -169,7 +169,7 @@ Pass ReduceBranchingThroughOvercompute() { return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute") .set_body_typed(ReduceBranchingThroughOvercompute); } // namespace transform diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 3e0e8a112169..6afaa0c61583 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -20,7 +20,7 @@ /*! * \file remap_thread_axis.cc */ -#include +#include #include #include #include @@ -103,7 +103,7 @@ Pass RemapThreadAxis(Map thread_map) { return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc index 928bcf02bc1b..ce7176e8cc46 100644 --- a/src/tir/transforms/remove_assume.cc +++ b/src/tir/transforms/remove_assume.cc @@ -21,7 +21,7 @@ * \file remove_store_undef.cc * \brief Remove stores of tir::builtin::undef */ -#include +#include #include #include #include @@ -61,7 +61,7 @@ Pass RemoveAssume() { return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); } // namespace transform diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 3b418aac0cf5..49dd41ae86a6 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -22,7 +22,7 @@ * \brief Remove no op from the stmt */ #include -#include +#include #include #include #include @@ -331,7 +331,7 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); } // namespace transform diff --git a/src/tir/transforms/remove_store_undef.cc b/src/tir/transforms/remove_store_undef.cc index 6b28cb165aa9..31b4a558c600 100644 --- a/src/tir/transforms/remove_store_undef.cc +++ b/src/tir/transforms/remove_store_undef.cc @@ -21,7 +21,7 @@ * \file remove_store_undef.cc * \brief Remove stores of tir::builtin::undef */ -#include +#include #include #include #include @@ -171,7 +171,7 @@ Pass RemoveStoreUndef() { "tir.RemoveStoreUndef"); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); } // namespace transform diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index e8d89bfb5700..881f321bf673 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -285,7 +285,7 @@ Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock") +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock") .set_body_typed(RemoveWeightLayoutRewriteBlock); } // namespace transform diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index db5bae8dfad5..cd1517b11c2a 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -290,7 +290,7 @@ class RenewDefMutator : public StmtExprMutator { PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } -TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); +TVM_FFI_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index beb5997d4982..0fb24c62500a 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -21,7 +21,7 @@ * \file renormalize_split_pattern.cc * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) */ -#include +#include #include #include #include @@ -205,7 +205,7 @@ Pass RenormalizeSplitPattern() { return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") +TVM_FFI_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") .set_body_typed(RenormalizeSplitPattern); } // namespace transform diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 7646d01f8e90..624e2d9921a9 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -21,7 +21,7 @@ * \file unsafe_select_rewrite.cc * \brief Rewrite uinsafe select expression. */ -#include +#include #include #include #include @@ -139,7 +139,7 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 71843f1cf401..82c5d4178401 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -25,7 +25,7 @@ #include "../../tir/transforms/simplify.h" #include -#include +#include #include #include #include @@ -359,7 +359,7 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); +TVM_FFI_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index d9cd6d35497c..98aea3da99d5 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include #include #include @@ -47,7 +47,7 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); +TVM_FFI_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index adf00f0b57c4..0b12bd02d482 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -21,9 +21,9 @@ * \file split_host_device.cc * \brief Split device function from host. */ +#include #include #include -#include #include #include #include @@ -168,7 +168,7 @@ Pass SplitHostDevice() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); +TVM_FFI_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 22c347066789..b8062e2a2f10 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -23,8 +23,8 @@ * Re-write data access to enable memory sharing when possible. */ #include +#include #include -#include #include #include #include @@ -1761,7 +1761,7 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); +TVM_FFI_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1770,7 +1770,7 @@ Pass PointerValueTypeRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") +TVM_FFI_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") .set_body_typed(PointerValueTypeRewrite); } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index e0ae7172ad5c..3c6a6fc9be86 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -21,7 +21,7 @@ * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ -#include +#include #include #include #include @@ -217,7 +217,7 @@ Pass InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index fd772863f780..34878d1b333d 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -20,7 +20,7 @@ /*! * \file thread_storage_sync.cc */ -#include +#include #include #include #include @@ -471,7 +471,7 @@ Pass ThreadSync(String storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index 5332fcfff123..3ef35d74cf8a 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -184,7 +184,7 @@ Pass TransformMmaBufferLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {}); } -TVM_REGISTER_GLOBAL("tir.transform.TransformMmaBufferLayout") +TVM_FFI_REGISTER_GLOBAL("tir.transform.TransformMmaBufferLayout") .set_body_typed(TransformMmaBufferLayout); } // namespace transform diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 664bb9a1633c..08fc921f4ebf 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -199,7 +199,7 @@ Pass UnifyThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); +TVM_FFI_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); } // namespace transform diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a68ebe7e02ff..7218adbda216 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -23,7 +23,7 @@ */ // Unrolls the loop as in Halide pipeline. #include -#include +#include #include #include #include @@ -288,7 +288,7 @@ Pass UnrollLoop() { return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); +TVM_FFI_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); } // namespace transform diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index da9d74abbe6a..c4d2d4608044 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -21,7 +21,7 @@ * \file unsupported_dtype_legalize.cc * \brief legalize bf16/fp8 type by adding cast_to_fp32 */ -#include +#include #include #include #include @@ -758,7 +758,7 @@ Pass BF16ComputeLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -771,7 +771,7 @@ Pass BF16StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -784,7 +784,7 @@ Pass FP8ComputeLegalize(String promote_dtype_str) { return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -797,7 +797,7 @@ Pass FP8StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 81b906786290..a1195cfef81f 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -381,7 +381,7 @@ Pass UseAssumeToReduceBranches() { return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") +TVM_FFI_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") .set_body_typed(UseAssumeToReduceBranches); } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 54b2daf83632..16aae03932cf 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -22,7 +22,7 @@ */ // Loop vectorizer as in Halide pipeline. #include -#include +#include #include #include #include @@ -1028,7 +1028,7 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); } // namespace transform diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 6d6dc4edc5f6..2a868145c94e 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -21,8 +21,8 @@ * \brief Registration of broadcast operators * \file broadcast.cc */ +#include #include -#include #include #include @@ -32,19 +32,19 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ - bool lhs_is_tensor = args[0].as().has_value(); \ - bool rhs_is_tensor = args[1].as().has_value(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_FFI_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ + bool lhs_is_tensor = args[0].as().has_value(); \ + bool rhs_is_tensor = args[1].as().has_value(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } \ }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); @@ -73,9 +73,10 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.broadcast_to") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = broadcast_to(args[0].cast(), args[1].cast>()); + }); } // namespace topi } // namespace tvm diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index e4ac103f14d6..40c8332ab725 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -355,7 +355,7 @@ Array InferEinsumShape(const std::string& subscripts, return einsum_builder.InferShape(); } -TVM_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = einsum(args[0].cast(), args[1].cast>()); }); diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index e3a3411a9c6c..05e59b971371 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -21,8 +21,8 @@ * \brief Registration of elemwise operators * \file elemwise.cc */ +#include #include -#include #include namespace tvm { @@ -31,139 +31,140 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sigmoid(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = rsqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log2(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log10(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = identity(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = negative(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = clip(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cast(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reinterpret(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = elemwise_sum(args[0].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.elemwise_sum") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = elemwise_sum(args[0].cast>()); + }); -TVM_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sign(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full_like(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = logical_not(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = bitwise_not(args[0].cast()); }); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index e4c9ae5f60e1..68ba43090ac9 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -21,8 +21,8 @@ * \brief Registration of NN operators * \file nn.cc */ +#include #include -#include #include #include #include @@ -45,127 +45,131 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = relu(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = leaky_relu(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.leaky_relu") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = leaky_relu(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = prelu(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = pad(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") +TVM_FFI_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = space_to_batch_nd(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") +TVM_FFI_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = batch_to_space_nd(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nll_loss(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast()); }); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dense(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); }); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::flatten(args[0].cast()); }); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") +TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::scale_shift_nchw(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") +TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), args[2].cast()); }); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - nn::pool_grad(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), static_cast(args[5].cast()), - args[6].cast(), args[7].cast(), args[8].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool_grad") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool_grad(args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::global_pool(args[0].cast(), static_cast(args[1].cast()), - args[2].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.global_pool") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::global_pool(args[0].cast(), + static_cast(args[1].cast()), + args[2].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool1d(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool3d(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool1d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool2d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool3d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), @@ -173,45 +177,49 @@ TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, f }); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::softmax(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::log_softmax(args[0].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.log_softmax") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::log_softmax(args[0].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast()); }); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") +TVM_FFI_REGISTER_GLOBAL("topi.nn.binarize_pack") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::binary_dense(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.binary_dense") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::binary_dense(args[0].cast(), args[1].cast()); + }); /* Ops from nn/layer_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), - args[4].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.layer_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::layer_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast>(), + args[4].cast()); + }); /* Ops from nn/group_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::group_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.group_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::group_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast(), + args[5].cast>(), args[6].cast()); + }); /* Ops from nn/instance_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.instance_norm") +TVM_FFI_REGISTER_GLOBAL("topi.nn.instance_norm") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast>(), @@ -219,7 +227,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm") }); /* Ops from nn/rms_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), args[2].cast>(), args[3].cast()); }); diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index e1720cc0b6b0..1720ddd60230 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -21,8 +21,8 @@ * \brief Registration of reduction operators * \file reduction.cc */ +#include #include -#include #include #include @@ -32,43 +32,44 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::sum(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::min(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::max(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmin(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmax(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::prod(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::all(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::any(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.collapse_sum") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); + }); } // namespace topi } // namespace tvm diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 450aded9459f..50aa50638266 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -21,8 +21,8 @@ * \brief Registration of transform operators * \file transform.cc */ +#include #include -#include #include #include #include @@ -37,55 +37,57 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = expand_dims(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = transpose(args[0].cast(), args[1].cast>>()); }); -TVM_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // pass empty seq_lengths tensor to reverse_sequence *rv = reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reverse_sequence") +TVM_FFI_REGISTER_GLOBAL("topi.reverse_sequence") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reverse_sequence(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reshape(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.sliding_window").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sliding_window") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = sliding_window(args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast>()); + }); -TVM_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = squeeze(args[0].cast(), ArrayOrInt(args[1])); }); -TVM_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = concatenate(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = stack(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = shape(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = ndarray_size(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.ndarray_size") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ndarray_size(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args[1].try_cast()) { *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); } else { @@ -94,13 +96,13 @@ TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi:: } }); -TVM_REGISTER_GLOBAL("topi.layout_transform") +TVM_FFI_REGISTER_GLOBAL("topi.layout_transform") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = layout_transform(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 4) { auto mode = args[3].cast(); int batch_dims = args[2].cast(); @@ -115,52 +117,55 @@ TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::A } }); -TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - double pad_val = args[2].cast(); - int axis = args[3].cast(); - *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sequence_mask") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + double pad_val = args[2].cast(); + int axis = args[3].cast(); + *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); + }); -TVM_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = where(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = arange(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = meshgrid(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = repeat(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tile(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = gather(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int batch_dims = args[2].cast(); *rv = gather_nd(args[0].cast(), args[1].cast(), batch_dims); }); -TVM_REGISTER_GLOBAL("topi.unravel_index").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = unravel_index(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.unravel_index") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = unravel_index(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sparse_to_dense") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), + args[2].cast(), args[3].cast()); + }); -TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { switch (args.size()) { case 2: *rv = matmul(args[0].cast(), args[1].cast()); @@ -177,7 +182,7 @@ TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi: } }); -TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 2) { *rv = tensordot(args[0].cast(), args[1].cast()); } else if (args.size() == 3) { @@ -189,34 +194,36 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, f } }); -TVM_REGISTER_GLOBAL("topi.strided_slice").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Tensor x = args[0].cast(); - Array begin = args[1].cast>(); - Array end = args[2].cast>(); - Array strides = args[3].cast>(); - Array axes = args[4].cast>(); - bool assume_inbound = args[6].cast(); - if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && - IsConstIntArray(x->shape)) { - Array begin_static = args[1].cast>(); - Array end_static = args[2].cast>(); - Array strides_static = args[3].cast>(); - auto slice_mode = args[5].cast(); - if (axes.size()) { - *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); - } else { - *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); - } - } else { - if (axes.size()) { - *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); - } else { - *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); - } - } -}); +TVM_FFI_REGISTER_GLOBAL("topi.strided_slice") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Tensor x = args[0].cast(); + Array begin = args[1].cast>(); + Array end = args[2].cast>(); + Array strides = args[3].cast>(); + Array axes = args[4].cast>(); + bool assume_inbound = args[6].cast(); + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && + IsConstIntArray(x->shape)) { + Array begin_static = args[1].cast>(); + Array end_static = args[2].cast>(); + Array strides_static = args[3].cast>(); + auto slice_mode = args[5].cast(); + if (axes.size()) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, + slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + if (axes.size()) { + *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); + } else { + *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); + } + } + }); -TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice") +TVM_FFI_REGISTER_GLOBAL("topi.dynamic_strided_slice") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { te::Tensor begin = args[1].cast(); te::Tensor end = args[2].cast(); @@ -224,13 +231,13 @@ TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice") *rv = dynamic_strided_slice(args[0].cast(), begin, end, strides); }); -TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") +TVM_FFI_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") .set_body_typed([](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides, Array output_shape) { return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); }); -TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int depth = args[3].cast(); int axis = args[4].cast(); DataType dtype = args[5].cast(); @@ -238,18 +245,18 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi depth, axis, dtype); }); -TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int k1 = args[2].cast(); - int k2 = args[3].cast(); - bool super_diag_right_align = args[4].cast(); - bool sub_diag_right_align = args[5].cast(); - *rv = matrix_set_diag(args[0].cast(), args[1].cast(), k1, k2, - super_diag_right_align, sub_diag_right_align); -}); +TVM_FFI_REGISTER_GLOBAL("topi.matrix_set_diag") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + int k1 = args[2].cast(); + int k2 = args[3].cast(); + bool super_diag_right_align = args[4].cast(); + bool sub_diag_right_align = args[5].cast(); + *rv = matrix_set_diag(args[0].cast(), args[1].cast(), k1, k2, + super_diag_right_align, sub_diag_right_align); + }); -TVM_REGISTER_GLOBAL("topi.adv_index").set_body_typed([](te::Tensor x, Array indices) { - return adv_index(x, indices); -}); +TVM_FFI_REGISTER_GLOBAL("topi.adv_index") + .set_body_typed([](te::Tensor x, Array indices) { return adv_index(x, indices); }); } // namespace topi } // namespace tvm diff --git a/src/topi/utils.cc b/src/topi/utils.cc index c02744a4202d..66da512a66e0 100644 --- a/src/topi/utils.cc +++ b/src/topi/utils.cc @@ -22,25 +22,25 @@ * \file utils.cc */ +#include #include -#include #include namespace tvm { namespace topi { -TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape") +TVM_FFI_REGISTER_GLOBAL("topi.utils.is_empty_shape") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::detail::is_empty_shape(args[0].cast>()); }); -TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") +TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nchw(args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") +TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nhwc(args[0].cast(), args[1].cast>(), diff --git a/src/topi/vision.cc b/src/topi/vision.cc index dca44bf86c3c..844f8f94592e 100644 --- a/src/topi/vision.cc +++ b/src/topi/vision.cc @@ -21,8 +21,8 @@ * \brief Registration of vision operators * \file vision.cc */ +#include #include -#include #include namespace tvm { @@ -31,9 +31,10 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = vision::reorg(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.vision.reorg") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = vision::reorg(args[0].cast(), args[1].cast()); + }); } // namespace topi } // namespace tvm diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index fa2a4aa45895..313b149e0987 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include @@ -38,7 +38,7 @@ namespace tvm { namespace runtime { namespace hexagon { -TVM_REGISTER_GLOBAL("hexagon.run_all_tests") +TVM_FFI_REGISTER_GLOBAL("hexagon.run_all_tests") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using delimiter and build argument vector diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc b/tests/cpp-runtime/hexagon/run_unit_tests.cc index d9331db28bee..9b55151638e6 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include @@ -80,7 +80,7 @@ class GtestPrinter : public testing::EmptyTestEventListener { std::string GetOutput() { return gtest_out_.str(); } }; -TVM_REGISTER_GLOBAL("hexagon.run_unit_tests") +TVM_FFI_REGISTER_GLOBAL("hexagon.run_unit_tests") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using delimiter and build argument vector diff --git a/tests/cpp-runtime/opencl/texture_copy_test.cc b/tests/cpp-runtime/opencl/texture_copy_test.cc index 75147bb5571d..61d9044b6d86 100644 --- a/tests/cpp-runtime/opencl/texture_copy_test.cc +++ b/tests/cpp-runtime/opencl/texture_copy_test.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include diff --git a/tests/cpp/llvm_codegen_registry_test.cc b/tests/cpp/llvm_codegen_registry_test.cc index 534d4c8e411b..49457fd0dac5 100644 --- a/tests/cpp/llvm_codegen_registry_test.cc +++ b/tests/cpp/llvm_codegen_registry_test.cc @@ -21,8 +21,8 @@ #include #include +#include #include -#include #include diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 28300dfdea4e..955cd58dc2ae 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -80,7 +80,7 @@ Which eventually jumps to the following line in C++, which creates a RPC client [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) ```cpp -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); @@ -94,7 +94,7 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi: [https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106) ```cpp -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto session_name = args[0].cast(); int remote_stack_size_bytes = args[1].cast(); diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 9ba89fe7f535..d06846400dca 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -410,8 +410,8 @@ def run_arr_test(): def test_rpc_return_remote_object(): def check(client, is_local): make_shape = client.get_function("ffi.Shape") - get_elem = client.get_function("runtime.GetShapeElem") - get_size = client.get_function("runtime.GetShapeSize") + get_elem = client.get_function("testing.GetShapeElem") + get_size = client.get_function("testing.GetShapeSize") shape = make_shape(2, 3) assert shape.type_key == "runtime.RPCObjectRef" assert get_elem(shape, 0) == 2 @@ -662,7 +662,7 @@ def test_compiled_function_with_zero_arguments(call_with_unused_argument): """RPC functions do not require an argument This is a regression test. When no arguments are provided, RPC - provides NULL as the `TVMValue* args` argument to a PackedFunc. + provides NULL as the `TVMFFIAny* args` argument to a PackedFunc. However, previous implementations of `MakePackedAPI` unconditionally asserted that the `args` pointer was non-null. This assertion is now generated only when the function accepts diff --git a/version.py b/version.py index db1f6cadfb8c..a8ae77a8d5f2 100644 --- a/version.py +++ b/version.py @@ -21,7 +21,7 @@ List of affected files: - tvm-root/python/tvm/_ffi/libinfo.py -- tvm-root/include/tvm/runtime/c_runtime_api.h +- tvm-root/include/tvm/runtime/base.h - tvm-root/conda/recipe/meta.yaml - tvm-root/web/package.json """ @@ -179,7 +179,7 @@ def sync_version(pub_ver, local_ver, dry_run): # Note that full git hash is already available in libtvm # C++ header update( - os.path.join(PROJ_ROOT, "include", "tvm", "runtime", "c_runtime_api.h"), + os.path.join(PROJ_ROOT, "include", "tvm", "runtime", "base.h"), r'(?<=TVM_VERSION ")[.0-9a-z\+]+', pub_ver, dry_run, diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 1e35a1137fb7..922b25b0d74b 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -58,7 +58,7 @@ TVM_DLL void TVMWasmFreeSpace(void* data); * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer 3A * \return 0 if success. */ -TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out); +TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFFIObjectHandle* out); /*! * \brief Get the last error message. @@ -94,7 +94,7 @@ void* TVMWasmAllocSpace(int size) { void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } -int TVMFFIWasmFunctionCreate(void* self, TVMFunctionHandle* out) { +int TVMFFIWasmFunctionCreate(void* self, TVMFFIObjectHandle* out) { return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall, TVMFFIWasmFunctionDeleter, out); } diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 728e1c648c28..40dfb31ad19f 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -32,18 +32,15 @@ #include -#include "src/runtime/c_runtime_api.cc" -#include "src/runtime/container.cc" #include "src/runtime/contrib/sort/sort.cc" #include "src/runtime/cpu_device_api.cc" +#include "src/runtime/device_api.cc" #include "src/runtime/file_utils.cc" #include "src/runtime/library_module.cc" #include "src/runtime/logging.cc" #include "src/runtime/module.cc" #include "src/runtime/ndarray.cc" -#include "src/runtime/object.cc" #include "src/runtime/profiling.cc" -#include "src/runtime/registry.cc" #include "src/runtime/rpc/rpc_channel.cc" #include "src/runtime/rpc/rpc_endpoint.cc" #include "src/runtime/rpc/rpc_event_impl.cc"