19
19
20
20
# NOTICE: The design adapted from:
21
21
# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
22
+ """Auto Accelerator Module."""
22
23
23
24
24
25
# To keep it simply, only add the APIs we need.
40
41
41
42
42
43
class AcceleratorRegistry :
44
+ """Accelerator Registry."""
45
+
43
46
registered_accelerators = {}
44
47
45
48
@classmethod
@@ -94,171 +97,253 @@ class CUDA_Accelerator:
94
97
name: the accelerator name.
95
98
priority: the priority of the accelerator. A larger number indicates a higher priority,
96
99
"""
97
-
98
100
return accelerator_registry .register_accelerator_impl (name = name , priority = priority )
99
101
100
102
101
103
class Auto_Accelerator(ABC):  # pragma: no cover
    """Auto Accelerator base class.

    Defines the minimal device-management API that every backend
    (cpu/cuda/xpu/hpu) must implement. Only the APIs the framework
    needs are declared here.
    """

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool:
        """Return True if this accelerator is usable in the current environment."""

    @abstractmethod
    def name(self) -> str:
        """Return the accelerator name (e.g. "cpu", "cuda")."""

    @abstractmethod
    def device_name(self, device_indx) -> str:
        """Return the device string for the given index (e.g. "cuda:0")."""

    @abstractmethod
    def set_device(self, device_index):
        """Make the device with the given index the current device."""

    @abstractmethod
    def current_device(self):
        """Return the current device (index or name; backend-specific)."""

    @abstractmethod
    def current_device_name(self):
        """Return the current device as a name string."""

    @abstractmethod
    def device(self, device_index=None):
        """Return a backend device object/context for the given index."""

    @abstractmethod
    def empty_cache(self):
        """Release cached device memory, if the backend supports it."""

    @abstractmethod
    def synchronize(self):
        """Block until all queued device work has finished."""

    def mark_step(self):
        """Trigger the lazy-mode graph to run; no-op by default."""
        pass
@register_accelerator(name="cpu", priority=PRIORITY_CPU)
class CPU_Accelerator(Auto_Accelerator):
    """CPU Accelerator.

    Device management is a no-op on CPU, so most methods simply
    return "cpu" or do nothing.
    """

    def __init__(self) -> None:
        """Initialize the CPU accelerator."""
        self._name = "cpu"

    def name(self) -> str:
        """Return the accelerator name."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """The CPU is always available."""
        return True

    def device_name(self, device_indx) -> str:
        """Return the device name; always "cpu" regardless of index."""
        return "cpu"

    def set_device(self, device_index):
        """No-op: there is only one CPU device."""
        pass

    def current_device(self):
        """Return the current device, always "cpu"."""
        return "cpu"

    def current_device_name(self):
        """Return the current device name, always "cpu"."""
        return "cpu"

    def device(self, device_index=None):
        """No-op: the CPU backend exposes no device object."""
        pass

    def empty_cache(self):
        """No-op: the CPU backend keeps no device cache."""
        pass

    def synchronize(self):
        """No-op: CPU execution is synchronous."""
        pass
@register_accelerator(name="cuda", priority=PRIORITY_CUDA)
class CUDA_Accelerator(Auto_Accelerator):  # pragma: no cover
    """CUDA Accelerator backed by torch.cuda."""

    def __init__(self) -> None:
        """Initialize the CUDA accelerator."""
        self._name = "cuda"

    def name(self) -> str:
        """Return the accelerator name."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """Return True when a CUDA device is available."""
        return torch.cuda.is_available()

    def device_name(self, device_indx) -> str:
        """Return "cuda" when no index is given, else "cuda:<index>"."""
        if device_indx is None:
            return "cuda"
        return f"cuda:{device_indx}"

    def synchronize(self):
        """Wait for all kernels on the current CUDA device to finish."""
        return torch.cuda.synchronize()

    def set_device(self, device_index):
        """Set the current CUDA device to the one with the given index."""
        return torch.cuda.set_device(device_index)

    def current_device(self):
        """Return the index of the current CUDA device."""
        return torch.cuda.current_device()

    def current_device_name(self):
        """Return the current CUDA device as a "cuda:<index>" string."""
        return f"cuda:{torch.cuda.current_device()}"

    def device(self, device_index=None):
        """Return a torch.cuda.device context object for the given index."""
        return torch.cuda.device(device_index)

    def empty_cache(self):
        """Release all unoccupied cached memory held by the CUDA allocator."""
        return torch.cuda.empty_cache()
@register_accelerator(name="xpu", priority=PRIORITY_XPU)
class XPU_Accelerator(Auto_Accelerator):  # pragma: no cover
    """XPU Accelerator backed by torch.xpu (Intel GPU)."""

    def __init__(self) -> None:
        """Initialize XPU Accelerator."""
        self._name = "xpu"

    def name(self) -> str:
        """Get the accelerator name."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """Checks if the 'xpu' device is available.

        Returns:
            bool: True if torch has XPU support and a device is present, False otherwise.
        """
        # hasattr guard: torch builds without XPU support have no torch.xpu attribute.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def device_name(self, device_indx) -> str:
        """Returns the name of the 'xpu' device with the given index.

        Args:
            device_indx (int): The index of the 'xpu' device, or None.

        Returns:
            str: "xpu" when no index is given, else "xpu:<index>".
        """
        if device_indx is None:
            return "xpu"
        return f"xpu:{device_indx}"

    def synchronize(self):
        """Synchronizes the 'xpu' device."""
        return torch.xpu.synchronize()

    def set_device(self, device_index):
        """Sets the current 'xpu' device to the one with the given index.

        Args:
            device_index (int): The index of the 'xpu' device.
        """
        return torch.xpu.set_device(device_index)

    def current_device(self):
        """Returns the index of the current 'xpu' device.

        Returns:
            int: The index of the current 'xpu' device.
        """
        return torch.xpu.current_device()

    def current_device_name(self):
        """Returns the name of the current 'xpu' device.

        Returns:
            str: The current device as an "xpu:<index>" string.
        """
        return "xpu:{}".format(torch.xpu.current_device())

    def device(self, device_index=None):
        """Returns a torch.device object for the 'xpu' device with the given index.

        Args:
            device_index (int, optional): The index of the 'xpu' device. Defaults to None.

        Returns:
            torch.device: The torch.device object for the 'xpu' device.
        """
        return torch.xpu.device(device_index)

    def empty_cache(self):
        """Empties the xpu cache."""
        return torch.xpu.empty_cache()
@register_accelerator(name="hpu", priority=PRIORITY_HPU)
class HPU_Accelerator(Auto_Accelerator):  # pragma: no cover
    """HPU Accelerator backed by torch.hpu (Habana Gaudi)."""

    def __init__(self) -> None:
        """Initialize the HPU accelerator."""
        self._name = "hpu"

    def name(self) -> str:
        """Return the accelerator name."""
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        """Return True when the Habana PyTorch extension and an HPU device are usable."""
        from .environ import is_hpex_available

        if is_hpex_available():
            # NOTE(review): success branch reconstructed from surrounding code
            # (source view elided it) — confirm against the original file.
            return torch.hpu.is_available()
        else:
            return False

    def device_name(self, device_indx) -> str:
        """Return "hpu" when no index is given, else "hpu:<index>"."""
        if device_indx is None:
            return "hpu"
        return f"hpu:{device_indx}"

    def synchronize(self):
        """Wait for all queued work on the HPU device to finish."""
        return torch.hpu.synchronize()

    def set_device(self, device_index):
        """Set the current HPU device; failures are logged, not raised."""
        try:
            torch.hpu.set_device(device_index)
        except Exception as e:
            logger.warning(e)

    def current_device(self):
        """Return the index of the current HPU device."""
        return torch.hpu.current_device()

    def current_device_name(self):
        """Return the current HPU device as an "hpu:<index>" string."""
        return "hpu:{}".format(torch.hpu.current_device())

    def device(self, device_index=None):
        """Return a torch.hpu.device context object for the given index."""
        return torch.hpu.device(device_index)

    def empty_cache(self):
        """Empty the HPU cache; failures are logged, not raised."""
        try:
            torch.hpu.empty_cache()
        except Exception as e:
            logger.warning(e)

    def mark_step(self):
        """Trigger the lazy-mode graph to run."""
        return htcore.mark_step()
def auto_detect_accelerator (device_name = "auto" ) -> Auto_Accelerator :
303
- # Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...
304
- # The `FORCE_DEVICE` is case insensitive.
305
- # The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
306
- # TODO: refine the docs and logic later
396
+ """Automatically detects and selects the appropriate accelerator.
397
+
398
+ Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...
399
+ The `FORCE_DEVICE` is case insensitive.
400
+ The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
401
+ TODO: refine the docs and logic later
402
+ """
307
403
# 1. Get the device setting from environment variable `FORCE_DEVICE`.
308
404
FORCE_DEVICE = os .environ .get ("FORCE_DEVICE" , None )
309
405
if FORCE_DEVICE :
0 commit comments