Skip to content

Commit c1aec15

Browse files
committed
Merge remote-tracking branch 'origin/main' into fix_bug3
2 parents 0e85916 + e771a76 commit c1aec15

File tree

205 files changed

+3589
-2339
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

205 files changed

+3589
-2339
lines changed

cgmanifests/generate_cgmanifest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def normalize_path_separators(path):
115115
submodule_lines = proc.stdout.splitlines()
116116
for submodule_line in submodule_lines:
117117
(absolute_path, url, commit) = submodule_line.split(" ")
118-
git_deps[GitDep(commit, url)] = "git submodule at {}".format(
119-
normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))
118+
git_deps[GitDep(commit, url)] = (
119+
f"git submodule at {normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))}"
120120
)
121121

122122
with open(os.path.join(SCRIPT_DIR, "..", "cmake", "deps.txt")) as f:

cmake/external/onnxruntime_external_deps.cmake

+7-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,13 @@ if (onnxruntime_BUILD_UNIT_TESTS)
3737
set(gtest_disable_pthreads ON)
3838
endif()
3939
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
40-
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
41-
# Needs to update onnxruntime/test/xctest/xcgtest.mm
40+
if (IOS OR ANDROID)
41+
# on mobile platforms the absl flags class dumps the flag names (presumably for binary size), which breaks passing
42+
# any args to gtest executables, such as using --gtest_filter to debug a specific test.
43+
# Processing of compile definitions:
44+
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21
45+
# If set, this code throws away the flag and does nothing on registration, which results in no flags being known:
46+
# https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217
4247
set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
4348
else()
4449
set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE)

cmake/onnxruntime_graph.cmake

+30-23
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,26 @@ file(GLOB_RECURSE onnxruntime_graph_src CONFIGURE_DEPENDS
77
"${ONNXRUNTIME_ROOT}/core/graph/*.cc"
88
)
99

10-
# create empty list for any excludes
10+
# start with empty training srcs list
11+
set(orttraining_graph_src)
12+
13+
if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
14+
set(orttraining_graph_src
15+
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
16+
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
17+
)
18+
endif()
19+
20+
if (onnxruntime_ENABLE_TRAINING)
21+
file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
22+
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
23+
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc"
24+
)
25+
endif()
26+
27+
# create empty lists for any excludes
1128
set(onnxruntime_graph_src_exclude_patterns)
29+
set(orttraining_graph_src_exclude_patterns)
1230

1331
if (onnxruntime_MINIMAL_BUILD)
1432
# remove schema registration support
@@ -22,11 +40,18 @@ if (onnxruntime_MINIMAL_BUILD)
2240
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc"
2341
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h"
2442
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc"
43+
"${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.h"
44+
"${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.cc"
2545
"${ONNXRUNTIME_ROOT}/core/graph/function_template.h"
2646
"${ONNXRUNTIME_ROOT}/core/graph/function_utils.h"
2747
"${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc"
2848
)
2949

50+
list(APPEND orttraining_graph_src_exclude_patterns
51+
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
52+
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
53+
)
54+
3055
# no Function support initially
3156
list(APPEND onnxruntime_graph_src_exclude_patterns
3257
"${ONNXRUNTIME_ROOT}/core/graph/function*"
@@ -64,30 +89,12 @@ endif()
6489
file(GLOB onnxruntime_graph_src_exclude ${onnxruntime_graph_src_exclude_patterns})
6590
list(REMOVE_ITEM onnxruntime_graph_src ${onnxruntime_graph_src_exclude})
6691

67-
file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS
68-
"${ONNXRUNTIME_ROOT}/core/defs/*.cc"
69-
)
70-
71-
if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
72-
set(orttraining_graph_src
73-
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
74-
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
75-
)
76-
endif()
77-
78-
if (onnxruntime_ENABLE_TRAINING)
79-
file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
80-
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
81-
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc"
82-
)
83-
endif()
84-
85-
set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
8692
if (onnxruntime_ENABLE_TRAINING_OPS)
87-
list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src})
93+
file(GLOB orttraining_graph_src_exclude ${orttraining_graph_src_exclude_patterns})
94+
list(REMOVE_ITEM orttraining_graph_src ${orttraining_graph_src_exclude})
8895
endif()
8996

90-
onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_lib_src})
97+
onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_src} ${orttraining_graph_src})
9198
add_dependencies(onnxruntime_graph onnx_proto flatbuffers::flatbuffers)
9299
onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common ${WIL_TARGET} onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11)
93100

@@ -120,7 +127,7 @@ endif()
120127

121128
set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")
122129
set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX)
123-
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
130+
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src})
124131
if (onnxruntime_ENABLE_TRAINING_OPS)
125132
source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src})
126133
endif()

cmake/onnxruntime_providers_coreml.cmake

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$")
7070
source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs})
7171

7272
# These are shared utils,
73-
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
74-
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
73+
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
74+
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
7575
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
7676
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
7777
)

cmake/onnxruntime_providers_nnapi.cmake

+2-4
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,10 @@
4949
endif()
5050

5151
# These are shared utils,
52-
# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
52+
# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
5353
list(APPEND onnxruntime_provider_nnapi_cc_src_patterns
5454
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
5555
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
56-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
57-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
5856
)
5957

6058
file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns})
@@ -81,4 +79,4 @@
8179
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
8280
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
8381
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
84-
endif()
82+
endif()

cmake/onnxruntime_providers_qnn.cmake

+3-5
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
add_compile_definitions(USE_QNN=1)
55

66
# These are shared utils,
7-
# TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML
8-
file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
7+
# TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML
8+
file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
99
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
1010
"${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
11-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
12-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
1311
)
1412

1513
file(GLOB_RECURSE
@@ -42,4 +40,4 @@
4240
# ignore the warning unknown-pragmas on "pragma region"
4341
if(NOT MSVC)
4442
target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas")
45-
endif()
43+
endif()

cmake/onnxruntime_providers_xnnpack.cmake

-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
"${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h"
88
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h"
99
"${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc"
10-
# utils for handling QDQ models
11-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
12-
"${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
1310
)
1411

1512
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})

docs/ORTModule_Convergence_Notes.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
8989
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
9090

9191
```diff
92-
+ from onnxruntime.training.utils import inspect_activation
92+
+ from onnxruntime.training.utils.hooks import inspect_activation
9393
class BloomForCausalLM(BloomPreTrainedModel):
9494
def __init__(self, config: BloomConfig):
9595
...

docs/python/examples/plot_train_convert_predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def loop(X_test, fct, n=None):
134134
nrow = X_test.shape[0]
135135
if n is None:
136136
n = nrow
137-
for i in range(0, n):
137+
for i in range(n):
138138
im = i % nrow
139139
fct(X_test[im : im + 1])
140140

include/onnxruntime/core/session/onnxruntime_c_api.h

+4
Original file line numberDiff line numberDiff line change
@@ -3619,6 +3619,10 @@ struct OrtApi {
36193619
* - "73"
36203620
* - "75"
36213621
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
3622+
"enable_htp_fp16_precision": Only used for float32 models.
3623+
Enable the float32 model to be inferenced with fp16 precision. Otherwise, fp32 precision will be used.
3624+
- "0": Default. With fp32 precision.
3625+
- "1": With fp16 precision.
36223626
*
36233627
* SNPE supported keys:
36243628
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",

js/common/lib/env.ts

+35
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,44 @@ export declare namespace Env {
143143
*/
144144
ondata?: (data: WebGpuProfilingData) => void;
145145
};
146+
/**
147+
* Set or get the power preference.
148+
*
149+
* Setting this property only has effect before the first WebGPU inference session is created. The value will be
150+
* used as options for `navigator.gpu.requestAdapter()`.
151+
*
152+
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
153+
*
154+
* @defaultValue `undefined`
155+
*/
156+
powerPreference?: 'low-power'|'high-performance';
157+
/**
158+
* Set or get the force fallback adapter flag.
159+
*
160+
* Setting this property only has effect before the first WebGPU inference session is created. The value will be
161+
* used as options for `navigator.gpu.requestAdapter()`.
162+
*
163+
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
164+
*
165+
* @defaultValue `undefined`
166+
*/
167+
forceFallbackAdapter?: boolean;
168+
/**
169+
* Get the adapter for WebGPU.
170+
*
171+
* This property is only available after the first WebGPU inference session is created.
172+
*
173+
* When used with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
174+
* Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
175+
*
176+
* see comments on {@link GpuBufferType}
177+
*/
178+
readonly adapter: unknown;
146179
/**
147180
* Get the device for WebGPU.
148181
*
182+
* This property is only available after the first WebGPU inference session is created.
183+
*
149184
* When used with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
150185
* Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
151186
*

js/web/lib/wasm/jsep/backend-webgpu.ts

+24-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view';
1010
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
1111
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
1212
import {ProgramManager} from './webgpu/program-manager';
13-
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
13+
import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
1414

1515
interface CommandInfo {
1616
readonly kernelId: number;
@@ -94,11 +94,32 @@ const getProgramInfoUniqueKey =
9494
return key;
9595
};
9696

97+
class AdapterInfoImpl implements AdapterInfo {
98+
readonly architecture?: string;
99+
readonly vendor?: string;
100+
101+
constructor(adapterInfo: GPUAdapterInfo) {
102+
if (adapterInfo) {
103+
this.architecture = adapterInfo.architecture;
104+
this.vendor = adapterInfo.vendor;
105+
}
106+
}
107+
108+
isArchitecture(architecture: GpuArchitecture): boolean {
109+
return this.architecture === architecture;
110+
}
111+
112+
isVendor(vendor: GpuVendor): boolean {
113+
return this.vendor === vendor;
114+
}
115+
}
116+
97117
/**
98118
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
99119
* the first parameter so that it is stored for future use.
100120
*/
101121
export class WebGpuBackend {
122+
adapterInfo: AdapterInfoImpl;
102123
device: GPUDevice;
103124
/**
104125
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
@@ -212,6 +233,7 @@ export class WebGpuBackend {
212233
}
213234

214235
this.device = await adapter.requestDevice(deviceDescriptor);
236+
this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo());
215237
this.gpuDataManager = createGpuDataManager(this);
216238
this.programManager = new ProgramManager(this);
217239
this.kernels = new Map();
@@ -231,6 +253,7 @@ export class WebGpuBackend {
231253
};
232254

233255
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
256+
Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter});
234257

235258
// init queryType, which is necessary for InferenceSession.create
236259
this.setQueryType();

js/web/lib/wasm/jsep/init.ts

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu';
1010
import {LOG_DEBUG} from './log';
1111
import {TensorView} from './tensor-view';
1212
import {ShapeUtil} from './util';
13-
import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
13+
import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
1414

1515
/* eslint-disable no-bitwise */
1616

@@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView {
5454
}
5555

5656
class ComputeContextImpl implements ComputeContext {
57+
readonly adapterInfo: AdapterInfo;
5758
readonly opKernelContext: number;
5859
readonly inputs: readonly TensorView[];
5960
readonly outputCount: number;
@@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext {
6667
private customDataOffset = 0;
6768
private customDataSize = 0;
6869
constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
70+
this.adapterInfo = backend.adapterInfo;
6971
const heapU32 = module.HEAPU32;
7072

7173
// extract context data

js/web/lib/wasm/jsep/webgpu/ops/conv.ts

+4-3
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
148148
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
149149
const isChannelsLast = attributes.format === 'NHWC';
150150
if (attributes.group !== 1) {
151-
// Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases:
151+
// NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
152+
// GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
152153
// [webgpu]Conv - conv - vectorize group - B
153154
// [webgpu]Conv - conv - vectorize group - D
154-
const disableGroupedConvVectorize = true;
155-
if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
155+
const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
156+
if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
156157
inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
157158
const outputShape = calculateOutputShape(
158159
inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,

0 commit comments

Comments
 (0)