15
15
using namespace ::testing;
16
16
using torch::executor::Error;
17
17
using torch::executor::KernelRuntimeContext;
18
+ using torch::executor::MemoryAllocator;
19
+ using torch::executor::Result;
18
20
19
21
class KernelRuntimeContextTest : public ::testing::Test {
20
22
public:
@@ -23,6 +25,17 @@ class KernelRuntimeContextTest : public ::testing::Test {
23
25
}
24
26
};
25
27
28
+ class TestMemoryAllocator : public MemoryAllocator {
29
+ public:
30
+ TestMemoryAllocator (uint32_t size, uint8_t * base_address)
31
+ : MemoryAllocator(size, base_address), last_seen_alignment(0 ) {}
32
+ void * allocate (size_t size, size_t alignment) override {
33
+ last_seen_alignment = alignment;
34
+ return MemoryAllocator::allocate (size, alignment);
35
+ }
36
+ size_t last_seen_alignment;
37
+ };
38
+
26
39
TEST_F (KernelRuntimeContextTest, FailureStateDefaultsToOk) {
27
40
KernelRuntimeContext context;
28
41
@@ -47,3 +60,43 @@ TEST_F(KernelRuntimeContextTest, FailureStateReflectsFailure) {
47
60
context.fail (Error::Ok);
48
61
EXPECT_EQ (context.failure_state (), Error::Ok);
49
62
}
63
+
64
+ TEST_F (KernelRuntimeContextTest, FailureNoMemoryAllocatorProvided) {
65
+ KernelRuntimeContext context;
66
+ Result<void *> allocated_memory = context.allocate_temp (4 );
67
+ EXPECT_EQ (allocated_memory.error (), Error::NotFound);
68
+ }
69
+
70
+ TEST_F (KernelRuntimeContextTest, SuccessfulMemoryAllocation) {
71
+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
72
+ auto temp_memory_allocator_pool =
73
+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
74
+ MemoryAllocator temp_allocator (
75
+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
76
+ KernelRuntimeContext context (nullptr , &temp_allocator);
77
+ Result<void *> allocated_memory = context.allocate_temp (4 );
78
+ EXPECT_EQ (allocated_memory.ok (), true );
79
+ }
80
+
81
+ TEST_F (KernelRuntimeContextTest, FailureMemoryAllocationInsufficientSpace) {
82
+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
83
+ auto temp_memory_allocator_pool =
84
+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
85
+ MemoryAllocator temp_allocator (
86
+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
87
+ KernelRuntimeContext context (nullptr , &temp_allocator);
88
+ Result<void *> allocated_memory = context.allocate_temp (8 );
89
+ EXPECT_EQ (allocated_memory.error (), Error::MemoryAllocationFailed);
90
+ }
91
+
92
+ TEST_F (KernelRuntimeContextTest, MemoryAllocatorAlignmentPassed) {
93
+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
94
+ auto temp_memory_allocator_pool =
95
+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
96
+ TestMemoryAllocator temp_allocator (
97
+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
98
+ KernelRuntimeContext context (nullptr , &temp_allocator);
99
+ Result<void *> allocated_memory = context.allocate_temp (4 , 2 );
100
+ EXPECT_EQ (allocated_memory.ok (), true );
101
+ EXPECT_EQ (temp_allocator.last_seen_alignment , 2 );
102
+ }
0 commit comments