1
1
from abc import ABC , abstractmethod
2
2
from typing import List
3
3
import uuid
4
+ import struct
5
+
6
+ import logging
7
+
8
+ logger = logging .getLogger (__name__ )
4
9
5
10
6
11
class SideChannel (ABC ):
@@ -16,15 +21,15 @@ def __init__(self, channel_id):
16
21
self ._channel_id : uuid .UUID = channel_id
17
22
self .message_queue : List [bytearray ] = []
18
23
19
- def queue_message_to_send (self , data : bytearray ) -> None :
24
+ def queue_message_to_send (self , msg : "OutgoingMessage" ) -> None :
20
25
"""
21
26
Queues a message to be sent by the environment at the next call to
22
27
step.
23
28
"""
24
- self .message_queue .append (data )
29
+ self .message_queue .append (msg . buffer )
25
30
26
31
@abstractmethod
27
- def on_message_received (self , data : bytes ) -> None :
32
+ def on_message_received (self , msg : "IncomingMessage" ) -> None :
28
33
"""
29
34
Is called by the environment to the side channel. Can be called
30
35
multiple times per step if multiple messages are meant for that
@@ -39,3 +44,51 @@ def channel_id(self) -> uuid.UUID:
39
44
processed in the environment.
40
45
"""
41
46
return self ._channel_id
47
+
48
+
49
+ class OutgoingMessage :
50
+ def __init__ (self ):
51
+ self .buffer = bytearray ()
52
+
53
+ def write_int32 (self , i : int ) -> None :
54
+ self .buffer += struct .pack ("<i" , i )
55
+
56
+ def write_float32 (self , f : float ) -> None :
57
+ self .buffer += struct .pack ("<f" , f )
58
+
59
+ def write_string (self , s : str ) -> None :
60
+ encoded_key = s .encode ("ascii" )
61
+ self .write_int32 (len (encoded_key ))
62
+ self .buffer += encoded_key
63
+
64
+ def set_raw_bytes (self , buffer : bytearray ) -> None :
65
+ if self .buffer :
66
+ logger .warning (
67
+ "Called set_raw_bytes but the message already has been written to. This will overwrite data."
68
+ )
69
+ self .buffer = bytearray (buffer )
70
+
71
+
72
+ class IncomingMessage :
73
+ def __init__ (self , buffer : bytes , offset : int = 0 ):
74
+ self .buffer = buffer
75
+ self .offset = offset
76
+
77
+ def read_int32 (self ) -> int :
78
+ val = struct .unpack_from ("<i" , self .buffer , self .offset )[0 ]
79
+ self .offset += 4
80
+ return val
81
+
82
+ def read_float32 (self ) -> float :
83
+ val = struct .unpack_from ("<f" , self .buffer , self .offset )[0 ]
84
+ self .offset += 4
85
+ return val
86
+
87
+ def read_string (self ) -> str :
88
+ encoded_str_len = self .read_int32 ()
89
+ val = self .buffer [self .offset : self .offset + encoded_str_len ].decode ("ascii" )
90
+ self .offset += encoded_str_len
91
+ return val
92
+
93
+ def get_raw_bytes (self ) -> bytes :
94
+ return bytearray (self .buffer )
0 commit comments