|
37 | 37 | import time
|
38 | 38 | from threading import Lock, RLock, Thread, Event
|
39 | 39 | import uuid
|
| 40 | +import os |
| 41 | +import urllib.request |
| 42 | +import json |
40 | 43 |
|
41 | 44 | import weakref
|
42 | 45 | from weakref import WeakValueDictionary
|
@@ -1169,6 +1172,26 @@ def __init__(self,
|
1169 | 1172 |
|
1170 | 1173 | uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection)
|
1171 | 1174 | uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection)
|
| 1175 | + |
| 1176 | + # Check if we need to download the secure connect bundle |
| 1177 | + if 'db_id' in cloud and 'token' in cloud: |
| 1178 | + # download SCB if necessary |
| 1179 | + if 'secure_connect_bundle' not in cloud: |
| 1180 | + bundle_path = f'astra-secure-connect-{cloud["db_id"]}.zip' |
| 1181 | + if not os.path.exists(bundle_path): |
| 1182 | + print('Downloading Secure Cloud Bundle') |
| 1183 | + url = self._get_astra_bundle_url(cloud['db_id'], cloud['token']) |
| 1184 | + try: |
| 1185 | + with urllib.request.urlopen(url) as r: |
| 1186 | + with open(bundle_path, 'wb') as f: |
| 1187 | + f.write(r.read()) |
| 1188 | + except urllib.error.URLError as e: |
| 1189 | + raise Exception(f"Error downloading secure connect bundle: {str(e)}") |
| 1190 | + cloud['secure_connect_bundle'] = bundle_path |
| 1191 | + # Set up auth_provider if not provided |
| 1192 | + if auth_provider is None: |
| 1193 | + auth_provider = PlainTextAuthProvider('token', cloud['token']) |
| 1194 | + |
1172 | 1195 | cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet)
|
1173 | 1196 |
|
1174 | 1197 | ssl_context = cloud_config.ssl_context
|
@@ -2184,6 +2207,29 @@ def get_control_connection_host(self):
|
2184 | 2207 | endpoint = connection.endpoint if connection else None
|
2185 | 2208 | return self.metadata.get_host(endpoint) if endpoint else None
|
2186 | 2209 |
|
| 2210 | + @staticmethod |
| 2211 | + def _get_astra_bundle_url(db_id, token): |
| 2212 | + # set up the request |
| 2213 | + url = f"https://api.astra.datastax.com/v2/databases/{db_id}/secureBundleURL" |
| 2214 | + headers = { |
| 2215 | + "Authorization": f"Bearer {token}", |
| 2216 | + "Content-Type": "application/json" |
| 2217 | + } |
| 2218 | + |
| 2219 | + req = urllib.request.Request(url, method="POST", headers=headers, data=b"") |
| 2220 | + try: |
| 2221 | + with urllib.request.urlopen(req) as response: |
| 2222 | + response_data = json.loads(response.read().decode()) |
| 2223 | + # happy path |
| 2224 | + if 'downloadURL' in response_data: |
| 2225 | + return response_data['downloadURL'] |
| 2226 | + # handle errors |
| 2227 | + if 'errors' in response_data: |
| 2228 | + raise Exception(response_data['errors'][0]['message']) |
| 2229 | + raise Exception('Unknown error in ' + str(response_data)) |
| 2230 | + except urllib.error.URLError as e: |
| 2231 | + raise Exception(f"Error connecting to Astra API: {str(e)}") |
| 2232 | + |
2187 | 2233 | def refresh_schema_metadata(self, max_schema_agreement_wait=None):
|
2188 | 2234 | """
|
2189 | 2235 | Synchronously refresh all schema metadata.
|
|
0 commit comments