Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

Commit 7d315c4

Browse files
committed
Add Websocket streaming support to base
1 parent 84d1284 commit 7d315c4

File tree

4 files changed

+358
-0
lines changed

4 files changed

+358
-0
lines changed

stream/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2017 The Kubernetes Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from stream import stream

stream/stream.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
# not use this file except in compliance with the License. You may obtain
3+
# a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
# License for the specific language governing permissions and limitations
11+
# under the License.
12+
13+
from . import ws_client
14+
15+
16+
def stream(func, *args, **kwargs):
17+
"""Stream given API call using websocket"""
18+
19+
def _intercept_request_call(
20+
method, url, query_params=None, headers=None,
21+
post_params=None, body=None, _preload_content=True,
22+
_request_timeout=None):
23+
#pylint: disable=unused-argument
24+
# Not all of the arguments in request call is necessary for
25+
# a websocket call, but we need to define them so we can
26+
# intercept the call.
27+
28+
# old generated code's api client has config. new ones has
29+
# configuration
30+
try:
31+
config = func.__self__.api_client.configuration
32+
except AttributeError:
33+
config = func.__self__.api_client.config
34+
35+
return ws_client.websocket_call(config,
36+
url,
37+
query_params=query_params,
38+
_request_timeout=_request_timeout,
39+
_preload_content=_preload_content,
40+
headers=headers)
41+
42+
prev_request = func.__self__.api_client.request
43+
try:
44+
func.__self__.api_client.request = _intercept_request_call
45+
return func(*args, **kwargs)
46+
finally:
47+
func.__self__.api_client.request = prev_request

stream/ws_client.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
2+
# not use this file except in compliance with the License. You may obtain
3+
# a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10+
# License for the specific language governing permissions and limitations
11+
# under the License.
12+
13+
from kubernetes.client.rest import ApiException
14+
15+
import select
16+
import certifi
17+
import time
18+
import collections
19+
from websocket import WebSocket, ABNF, enableTrace
20+
import six
21+
import ssl
22+
from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
23+
24+
STDIN_CHANNEL = 0
25+
STDOUT_CHANNEL = 1
26+
STDERR_CHANNEL = 2
27+
ERROR_CHANNEL = 3
28+
RESIZE_CHANNEL = 4
29+
30+
31+
class WSClient:
32+
def __init__(self, configuration, url, headers):
33+
"""A websocket client with support for channels.
34+
35+
Exec command uses different channels for different streams. for
36+
example, 0 is stdin, 1 is stdout and 2 is stderr. Some other API calls
37+
like port forwarding can forward different pods' streams to different
38+
channels.
39+
"""
40+
enableTrace(False)
41+
header = []
42+
self._connected = False
43+
self._channels = {}
44+
self._all = ""
45+
46+
# We just need to pass the Authorization, ignore all the other
47+
# http headers we get from the generated code
48+
if headers and 'authorization' in headers:
49+
header.append("authorization: %s" % headers['authorization'])
50+
51+
if configuration.ws_streaming_protocol:
52+
header.append("Sec-WebSocket-Protocol: %s" %
53+
configuration.ws_streaming_protocol)
54+
55+
if url.startswith('wss://') and configuration.verify_ssl:
56+
ssl_opts = {
57+
'cert_reqs': ssl.CERT_REQUIRED,
58+
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
59+
}
60+
if configuration.assert_hostname is not None:
61+
ssl_opts['check_hostname'] = configuration.assert_hostname
62+
else:
63+
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
64+
65+
if configuration.cert_file:
66+
ssl_opts['certfile'] = configuration.cert_file
67+
if configuration.key_file:
68+
ssl_opts['keyfile'] = configuration.key_file
69+
70+
self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
71+
self.sock.connect(url, header=header)
72+
self._connected = True
73+
74+
def peek_channel(self, channel, timeout=0):
75+
"""Peek a channel and return part of the input,
76+
empty string otherwise."""
77+
self.update(timeout=timeout)
78+
if channel in self._channels:
79+
return self._channels[channel]
80+
return ""
81+
82+
def read_channel(self, channel, timeout=0):
83+
"""Read data from a channel."""
84+
if channel not in self._channels:
85+
ret = self.peek_channel(channel, timeout)
86+
else:
87+
ret = self._channels[channel]
88+
if channel in self._channels:
89+
del self._channels[channel]
90+
return ret
91+
92+
def readline_channel(self, channel, timeout=None):
93+
"""Read a line from a channel."""
94+
if timeout is None:
95+
timeout = float("inf")
96+
start = time.time()
97+
while self.is_open() and time.time() - start < timeout:
98+
if channel in self._channels:
99+
data = self._channels[channel]
100+
if "\n" in data:
101+
index = data.find("\n")
102+
ret = data[:index]
103+
data = data[index+1:]
104+
if data:
105+
self._channels[channel] = data
106+
else:
107+
del self._channels[channel]
108+
return ret
109+
self.update(timeout=(timeout - time.time() + start))
110+
111+
def write_channel(self, channel, data):
112+
"""Write data to a channel."""
113+
self.sock.send(chr(channel) + data)
114+
115+
def peek_stdout(self, timeout=0):
116+
"""Same as peek_channel with channel=1."""
117+
return self.peek_channel(STDOUT_CHANNEL, timeout=timeout)
118+
119+
def read_stdout(self, timeout=None):
120+
"""Same as read_channel with channel=1."""
121+
return self.read_channel(STDOUT_CHANNEL, timeout=timeout)
122+
123+
def readline_stdout(self, timeout=None):
124+
"""Same as readline_channel with channel=1."""
125+
return self.readline_channel(STDOUT_CHANNEL, timeout=timeout)
126+
127+
def peek_stderr(self, timeout=0):
128+
"""Same as peek_channel with channel=2."""
129+
return self.peek_channel(STDERR_CHANNEL, timeout=timeout)
130+
131+
def read_stderr(self, timeout=None):
132+
"""Same as read_channel with channel=2."""
133+
return self.read_channel(STDERR_CHANNEL, timeout=timeout)
134+
135+
def readline_stderr(self, timeout=None):
136+
"""Same as readline_channel with channel=2."""
137+
return self.readline_channel(STDERR_CHANNEL, timeout=timeout)
138+
139+
def read_all(self):
140+
"""Return buffered data received on stdout and stderr channels.
141+
This is useful for non-interactive call where a set of command passed
142+
to the API call and their result is needed after the call is concluded.
143+
Should be called after run_forever() or update()
144+
145+
TODO: Maybe we can process this and return a more meaningful map with
146+
channels mapped for each input.
147+
"""
148+
out = self._all
149+
self._all = ""
150+
self._channels = {}
151+
return out
152+
153+
def is_open(self):
154+
"""True if the connection is still alive."""
155+
return self._connected
156+
157+
def write_stdin(self, data):
158+
"""The same as write_channel with channel=0."""
159+
self.write_channel(STDIN_CHANNEL, data)
160+
161+
def update(self, timeout=0):
162+
"""Update channel buffers with at most one complete frame of input."""
163+
if not self.is_open():
164+
return
165+
if not self.sock.connected:
166+
self._connected = False
167+
return
168+
r, _, _ = select.select(
169+
(self.sock.sock, ), (), (), timeout)
170+
if r:
171+
op_code, frame = self.sock.recv_data_frame(True)
172+
if op_code == ABNF.OPCODE_CLOSE:
173+
self._connected = False
174+
return
175+
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
176+
data = frame.data
177+
if six.PY3:
178+
data = data.decode("utf-8")
179+
if len(data) > 1:
180+
channel = ord(data[0])
181+
data = data[1:]
182+
if data:
183+
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
184+
# keeping all messages in the order they received for
185+
# non-blocking call.
186+
self._all += data
187+
if channel not in self._channels:
188+
self._channels[channel] = data
189+
else:
190+
self._channels[channel] += data
191+
192+
def run_forever(self, timeout=None):
193+
"""Wait till connection is closed or timeout reached. Buffer any input
194+
received during this time."""
195+
if timeout:
196+
start = time.time()
197+
while self.is_open() and time.time() - start < timeout:
198+
self.update(timeout=(timeout - time.time() + start))
199+
else:
200+
while self.is_open():
201+
self.update(timeout=None)
202+
203+
def close(self, **kwargs):
204+
"""
205+
close websocket connection.
206+
"""
207+
self._connected = False
208+
if self.sock:
209+
self.sock.close(**kwargs)
210+
211+
212+
WSResponse = collections.namedtuple('WSResponse', ['data'])
213+
214+
215+
def get_websocket_url(url):
216+
parsed_url = urlparse(url)
217+
parts = list(parsed_url)
218+
if parsed_url.scheme == 'http':
219+
parts[0] = 'ws'
220+
elif parsed_url.scheme == 'https':
221+
parts[0] = 'wss'
222+
return urlunparse(parts)
223+
224+
225+
def websocket_call(configuration, url, query_params, _request_timeout,
226+
_preload_content, headers):
227+
"""An internal function to be called in api-client when a websocket
228+
connection is required."""
229+
230+
# Extract the command from the list of tuples
231+
commands = None
232+
for key, value in query_params:
233+
if key == 'command':
234+
commands = value
235+
break
236+
237+
# drop command from query_params as we will be processing it separately
238+
query_params = [(key, value) for key, value in query_params if
239+
key != 'command']
240+
241+
# if we still have query params then encode them
242+
if query_params:
243+
url += '?' + urlencode(query_params)
244+
245+
# tack on the actual command to execute at the end
246+
if isinstance(commands, list):
247+
for command in commands:
248+
url += "&command=%s&" % quote_plus(command)
249+
elif commands is not None:
250+
url += '&command=' + quote_plus(commands)
251+
252+
try:
253+
client = WSClient(configuration, get_websocket_url(url), headers)
254+
if not _preload_content:
255+
return client
256+
client.run_forever(timeout=_request_timeout)
257+
return WSResponse('%s' % ''.join(client.read_all()))
258+
except (Exception, KeyboardInterrupt, SystemExit) as e:
259+
raise ApiException(status=0, reason=str(e))

stream/ws_client_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2017 The Kubernetes Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from .ws_client import get_websocket_url
18+
19+
20+
class WSClientTest(unittest.TestCase):
21+
22+
def test_websocket_client(self):
23+
for url, ws_url in [
24+
('http://localhost/api', 'ws://localhost/api'),
25+
('https://localhost/api', 'wss://localhost/api'),
26+
('https://domain.com/api', 'wss://domain.com/api'),
27+
('https://api.domain.com/api', 'wss://api.domain.com/api'),
28+
('http://api.domain.com', 'ws://api.domain.com'),
29+
('https://api.domain.com', 'wss://api.domain.com'),
30+
('http://api.domain.com/', 'ws://api.domain.com/'),
31+
('https://api.domain.com/', 'wss://api.domain.com/'),
32+
]:
33+
self.assertEqual(get_websocket_url(url), ws_url)
34+
35+
36+
if __name__ == '__main__':
37+
unittest.main()

0 commit comments

Comments
 (0)