Skip to content

feat: supporting managed identity #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import json
from typing import Union

from azure.storage.blob import BlobClient as BlobClientSdk
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import validate_connection_string
from .utils import get_connection_string, using_managed_identity


class BlobClient(SdkType):
def __init__(self, *, data: Union[bytes, Datum]) -> None:
# model_binding_data properties
self._data = data
self._using_managed_identity = False
self._version = None
self._source = None
self._content_type = None
Expand All @@ -24,17 +25,32 @@ def __init__(self, *, data: Union[bytes, Datum]) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = validate_connection_string(content_json["Connection"])
self._containerName = content_json["ContainerName"]
self._blobName = content_json["BlobName"]
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")

# Returns a BlobClient
def get_sdk_type(self):
"""
When using Managed Identity, the only way to create a BlobClient is
through a BlobServiceClient. There are two ways to create a
BlobServiceClient:
1. Through the constructor: this is the only option when using Managed Identity
2. Through from_connection_string: this is the only option when not using Managed Identity

We track if Managed Identity is being used through a flag.
"""
if self._data:
return BlobClientSdk.from_connection_string(
conn_str=self._connection,
container_name=self._containerName,
blob_name=self._blobName,
blob_service_client = (
BlobServiceClient(account_url=self._connection)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
return blob_service_client.get_blob_client(
container=self._containerName,
blob=self._blobName,
)
else:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import json
from typing import Union

from azure.storage.blob import ContainerClient as ContainerClientSdk
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import validate_connection_string
from .utils import get_connection_string, using_managed_identity


class ContainerClient(SdkType):
def __init__(self, *, data: Union[bytes, Datum]) -> None:
# model_binding_data properties
self._data = data
self._using_managed_identity = False
self._version = ""
self._source = ""
self._content_type = ""
Expand All @@ -24,15 +25,23 @@ def __init__(self, *, data: Union[bytes, Datum]) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = validate_connection_string(content_json["Connection"])
self._containerName = content_json["ContainerName"]
self._blobName = content_json["BlobName"]
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")

# Returns a ContainerClient
def get_sdk_type(self):
if self._data:
return ContainerClientSdk.from_connection_string(
conn_str=self._connection, container_name=self._containerName
blob_service_client = (
BlobServiceClient(account_url=self._connection)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
return blob_service_client.get_container_client(
container=self._containerName
)
else:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import json
from typing import Union

from azure.storage.blob import BlobClient as BlobClientSdk
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import validate_connection_string
from .utils import get_connection_string, using_managed_identity


class StorageStreamDownloader(SdkType):
def __init__(self, *, data: Union[bytes, Datum]) -> None:
# model_binding_data properties
self._data = data or {}
self._data = data
self._using_managed_identity = False
self._version = ""
self._source = ""
self._content_type = ""
Expand All @@ -24,20 +25,25 @@ def __init__(self, *, data: Union[bytes, Datum]) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = validate_connection_string(content_json["Connection"])
self._containerName = content_json["ContainerName"]
self._blobName = content_json["BlobName"]
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")

# Returns a StorageStreamDownloader
def get_sdk_type(self):
if self._data:
# Create BlobClient
blob_client = BlobClientSdk.from_connection_string(
conn_str=self._connection,
container_name=self._containerName,
blob_name=self._blobName,
blob_service_client = (
BlobServiceClient(account_url=self._connection)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
# download_blob() returns a StorageStreamDownloader object
return blob_client.download_blob()
return blob_service_client.get_blob_client(
container=self._containerName,
blob=self._blobName,
).download_blob()
else:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,47 @@
import os


def validate_connection_string(connection_string: str) -> str:
def get_connection_string(connection_string: str) -> str:
"""
Validates the connection string. If the connection string is
Validates and returns the connection string. If the connection string is
not an App Setting, an error will be thrown.

When using managed identity, the connection string variable name is formatted like so:
Input: <CONNECTION_NAME_PREFIX>__serviceUri
Trigger: <CONNECTION_NAME_PREFIX>__blobServiceUri
The variable received will be <CONNECTION_NAME_PREFIX>. Therefore, we need to append
the suffix to obtain the storage URI and create the client.

There are four cases:
1. Not using managed identity: the environment variable exists as is
2. Using managed identity for blob input: __serviceUri must be appended
3. Using managed identity for blob trigger: __blobServiceUri must be appended
4. None of the above cases applies, so the connection variable is invalid.
"""
if connection_string == None:
if connection_string is None:
raise ValueError(
"Storage account connection string cannot be none. "
"Storage account connection string cannot be None. "
"Please provide a connection string."
)
elif not os.getenv(connection_string):
elif connection_string in os.environ:
return os.getenv(connection_string)
elif connection_string + "__serviceUri" in os.environ:
return os.getenv(connection_string + "__serviceUri")
elif connection_string + "__blobServiceUri" in os.environ:
return os.getenv(connection_string + "__blobServiceUri")
else:
raise ValueError(
f"Storage account connection string {connection_string} does not exist. "
f"Please make sure that it is a defined App Setting."
)
return os.getenv(connection_string)


def using_managed_identity(connection_name: str) -> bool:
    """
    Report whether the binding's connection resolves to a managed identity URI.

    Managed identity is configured through an environment variable (App
    Setting) named after the connection plus a suffix:
        Input:   <connection_name>__serviceUri
        Trigger: <connection_name>__blobServiceUri

    Args:
        connection_name: The connection name received from the binding
            data, without any suffix. ``None`` is tolerated and treated as
            "not using managed identity".

    Returns:
        True only when no environment variable named exactly
        ``connection_name`` exists and at least one of the suffixed
        variables does; False otherwise.
    """
    if connection_name is None:
        # Avoid a TypeError from `None + str`; a missing connection name
        # cannot be using managed identity.
        return False
    if connection_name in os.environ:
        # get_connection_string() prefers the plain App Setting when it
        # exists, and that value is a connection string rather than an
        # account URL. Reporting True here would make callers pass a
        # connection string as `account_url`.
        return False
    return any(
        connection_name + suffix in os.environ
        for suffix in ("__serviceUri", "__blobServiceUri")
    )
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,63 @@ def test_none_input_populated(self):
)
self.assertEqual(
e.exception.args[0],
"Storage account connection string cannot be none. Please provide a connection string.",
"Storage account connection string cannot be None. Please provide a connection string.",
)

def test_input_populated_managed_identity_input(self):
content = {
"Connection": "input",
"ContainerName": "test-blob",
"BlobName": "text.txt",
}

sample_mbd = MockMBD(
version="1.0",
source="AzureStorageBlobs",
content_type="application/json",
content=json.dumps(content),
)

datum: Datum = Datum(value=sample_mbd, type="model_binding_data")
result: BlobClient = BlobClientConverter.decode(
data=datum, trigger_metadata=None, pytype=BlobClient
)

self.assertIsNotNone(result)
self.assertIsInstance(result, BlobClientSdk)

sdk_result = BlobClient(data=datum.value).get_sdk_type()

self.assertIsNotNone(sdk_result)
self.assertIsInstance(sdk_result, BlobClientSdk)

def test_input_populated_managed_identity_trigger(self):
content = {
"Connection": "trigger",
"ContainerName": "test-blob",
"BlobName": "text.txt",
}

sample_mbd = MockMBD(
version="1.0",
source="AzureStorageBlobs",
content_type="application/json",
content=json.dumps(content),
)

datum: Datum = Datum(value=sample_mbd, type="model_binding_data")
result: BlobClient = BlobClientConverter.decode(
data=datum, trigger_metadata=None, pytype=BlobClient
)

self.assertIsNotNone(result)
self.assertIsInstance(result, BlobClientSdk)

sdk_result = BlobClient(data=datum.value).get_sdk_type()

self.assertIsNotNone(sdk_result)
self.assertIsInstance(sdk_result, BlobClientSdk)

def test_input_invalid_pytype(self):
content = {
"Connection": "AzureWebJobsStorage",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,63 @@ def test_none_input_populated(self):
)
self.assertEqual(
e.exception.args[0],
"Storage account connection string cannot be none. Please provide a connection string.",
"Storage account connection string cannot be None. Please provide a connection string.",
)

def test_input_populated_managed_identity_input(self):
content = {
"Connection": "input",
"ContainerName": "test-blob",
"BlobName": "text.txt",
}

sample_mbd = MockMBD(
version="1.0",
source="AzureStorageBlobs",
content_type="application/json",
content=json.dumps(content),
)

datum: Datum = Datum(value=sample_mbd, type="model_binding_data")
result: ContainerClient = BlobClientConverter.decode(
data=datum, trigger_metadata=None, pytype=ContainerClient
)

self.assertIsNotNone(result)
self.assertIsInstance(result, ContainerClientSdk)

sdk_result = ContainerClient(data=datum.value).get_sdk_type()

self.assertIsNotNone(sdk_result)
self.assertIsInstance(sdk_result, ContainerClientSdk)

def test_input_populated_managed_identity_trigger(self):
content = {
"Connection": "trigger",
"ContainerName": "test-blob",
"BlobName": "text.txt",
}

sample_mbd = MockMBD(
version="1.0",
source="AzureStorageBlobs",
content_type="application/json",
content=json.dumps(content),
)

datum: Datum = Datum(value=sample_mbd, type="model_binding_data")
result: ContainerClient = BlobClientConverter.decode(
data=datum, trigger_metadata=None, pytype=ContainerClient
)

self.assertIsNotNone(result)
self.assertIsInstance(result, ContainerClientSdk)

sdk_result = ContainerClient(data=datum.value).get_sdk_type()

self.assertIsNotNone(sdk_result)
self.assertIsInstance(sdk_result, ContainerClientSdk)

def test_input_invalid_pytype(self):
content = {
"Connection": "AzureWebJobsStorage",
Expand Down
Loading
Loading