|
12 | 12 | #import <coreml_backend/delegate.h>
|
13 | 13 | #import <executorch/runtime/core/evalue.h>
|
14 | 14 | #import <executorch/runtime/platform/log.h>
|
| 15 | +#import <executorch/runtime/kernel/kernel_includes.h> |
15 | 16 | #import <memory>
|
16 | 17 | #import <model_event_logger.h>
|
17 | 18 | #import <model_logging_options.h>
|
18 | 19 | #import <multiarray.h>
|
19 | 20 | #import <objc_safe_cast.h>
|
20 | 21 | #import <unordered_map>
|
21 | 22 | #import <vector>
|
| 23 | +#include <array> |
22 | 24 |
|
23 | 25 | #ifdef ET_EVENT_TRACER_ENABLED
|
24 | 26 | #import <model_event_logger_impl.h>
|
|
40 | 42 | using executorch::runtime::FreeableBuffer;
|
41 | 43 | using executorch::runtime::get_backend_class;
|
42 | 44 | using executorch::runtime::Result;
|
| 45 | +using executorch::aten::SizesType; |
| 46 | +using executorch::aten::Tensor; |
| 47 | +using executorch::runtime::kTensorDimensionLimit; |
43 | 48 |
|
44 | 49 | std::optional<MultiArray::DataType> get_data_type(ScalarType scalar_type) {
|
45 | 50 | switch (scalar_type) {
|
@@ -221,6 +226,21 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
|
221 | 226 | ETCoreMLStrings.delegateIdentifier.UTF8String);
|
222 | 227 | #endif
|
223 | 228 |
|
| 229 | + // Resize for dynamic shape |
| 230 | + std::array<SizesType, kTensorDimensionLimit> new_shape; |
| 231 | + for (size_t i = nInputs; i < nInputs + nOutputs; i++) { |
| 232 | + Tensor& t = args[i]->toTensor(); |
| 233 | + int rank = delegate_args[i].layout().rank(); |
| 234 | + assert (rank <= new_shape.size()); |
| 235 | + for (int d = 0; d < rank; d++) { |
| 236 | + new_shape[d] = delegate_args[i].layout().shape()[d]; |
| 237 | + } |
| 238 | + ET_CHECK_OR_RETURN_ERROR( |
| 239 | + resize_tensor(t, ArrayRef(new_shape.data(), rank)) == Error::Ok, |
| 240 | + DelegateInvalidHandle, |
| 241 | + "%s: Failed to resize delegate output %zu", ETCoreMLStrings.delegateIdentifier.UTF8String, i); |
| 242 | + } |
| 243 | + |
224 | 244 | return Error::Ok;
|
225 | 245 | }
|
226 | 246 |
|
|
0 commit comments