Skip to content

Commit edc55a0

Browse files
committed
[mypy] annotates vars for other/lru_cache
1 parent 04c8cba commit edc55a0

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

other/lru_cache.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ class DoubleLinkedListNode:
1212
Node: key: 1, val: 1, has next: False, has prev: False
1313
"""
1414

15-
def __init__(self, key: int, val: int):
15+
def __init__(self, key: int | None, val: int | None):
1616
self.key = key
1717
self.val = val
18-
self.next = None
19-
self.prev = None
18+
self.next: DoubleLinkedListNode | None = None
19+
self.prev: DoubleLinkedListNode | None = None
2020

2121
def __repr__(self) -> str:
2222
return "Node: key: {}, val: {}, has next: {}, has prev: {}".format(
@@ -91,7 +91,7 @@ class DoubleLinkedList:
9191
9292
"""
9393

94-
def __init__(self):
94+
def __init__(self) -> None:
9595
self.head = DoubleLinkedListNode(None, None)
9696
self.rear = DoubleLinkedListNode(None, None)
9797
self.head.next, self.rear.prev = self.rear, self.head
@@ -111,6 +111,10 @@ def add(self, node: DoubleLinkedListNode) -> None:
111111
"""
112112

113113
previous = self.rear.prev
114+
115+
# All nodes other than self.head are guaranteed to have non-None previous
116+
assert previous is not None
117+
114118
previous.next = node
115119
node.prev = previous
116120
self.rear.prev = node
@@ -136,6 +140,7 @@ def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
136140
return node
137141

138142

143+
# class LRUCache(Generic[T]):
139144
class LRUCache:
140145
"""
141146
LRU Cache to store a given capacity of data. Can be used as a stand-alone object
@@ -201,15 +206,17 @@ class LRUCache:
201206
"""
202207

203208
# class variable to map the decorator functions to their respective instance
204-
decorator_function_to_instance_map = {}
209+
decorator_function_to_instance_map: dict[Callable, LRUCache] = {}
205210

206211
def __init__(self, capacity: int):
207212
self.list = DoubleLinkedList()
208213
self.capacity = capacity
209214
self.num_keys = 0
210215
self.hits = 0
211216
self.miss = 0
212-
self.cache = {}
217+
# self.cache: dict[int, int] = {}
218+
# self.cache: dict[int, T] = {}
219+
self.cache: dict[int, DoubleLinkedListNode] = {}
213220

214221
def __repr__(self) -> str:
215222
"""
@@ -245,8 +252,14 @@ def get(self, key: int) -> int | None:
245252

246253
if key in self.cache:
247254
self.hits += 1
248-
self.list.add(self.list.remove(self.cache[key]))
249-
return self.cache[key].val
255+
value_node = self.cache[key]
256+
node = self.list.remove(self.cache[key])
257+
assert node == value_node
258+
259+
# node is guaranteed not None because it is in self.cache
260+
assert node is not None
261+
self.list.add(node)
262+
return node.val
250263
self.miss += 1
251264
return None
252265

@@ -257,16 +270,25 @@ def set(self, key: int, value: int) -> None:
257270

258271
if key not in self.cache:
259272
if self.num_keys >= self.capacity:
260-
key_to_delete = self.list.head.next.key
261-
self.list.remove(self.cache[key_to_delete])
262-
del self.cache[key_to_delete]
273+
# delete first node (oldest) when over capacity
274+
first_node = self.list.head.next
275+
276+
# guaranteed to have a non-None first node when num_keys > 0
277+
# explain to type checker via assertions
278+
assert first_node is not None
279+
assert first_node.key is not None
280+
assert self.list.remove(first_node) is not None # node guaranteed to be in list assert node.key is not None
281+
282+
del self.cache[first_node.key]
263283
self.num_keys -= 1
264284
self.cache[key] = DoubleLinkedListNode(key, value)
265285
self.list.add(self.cache[key])
266286
self.num_keys += 1
267287

268288
else:
289+
# bump node to the end of the list, update value
269290
node = self.list.remove(self.cache[key])
291+
assert node is not None # node guaranteed to be in list
270292
node.val = value
271293
self.list.add(node)
272294

@@ -289,10 +311,10 @@ def cache_decorator_wrapper(*args, **kwargs):
289311
)
290312
return result
291313

292-
def cache_info():
314+
def cache_info() -> LRUCache:
293315
return LRUCache.decorator_function_to_instance_map[func]
294316

295-
cache_decorator_wrapper.cache_info = cache_info
317+
setattr(cache_decorator_wrapper, "cache_info", cache_info)
296318

297319
return cache_decorator_wrapper
298320

0 commit comments

Comments
 (0)