diff --git a/libraries/botbuilder-ai/tests/qna/test_qna.py b/libraries/botbuilder-ai/tests/qna/test_qna.py index 10dbd5e89..cae752861 100644 --- a/libraries/botbuilder-ai/tests/qna/test_qna.py +++ b/libraries/botbuilder-ai/tests/qna/test_qna.py @@ -1,963 +1,965 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -# pylint: disable=protected-access - -import json -from os import path -from typing import List, Dict -import unittest -from unittest.mock import patch -from aiohttp import ClientSession - -import aiounittest -from botbuilder.ai.qna import QnAMakerEndpoint, QnAMaker, QnAMakerOptions -from botbuilder.ai.qna.models import ( - FeedbackRecord, - Metadata, - QueryResult, - QnARequestContext, -) -from botbuilder.ai.qna.utils import QnATelemetryConstants -from botbuilder.core import BotAdapter, BotTelemetryClient, TurnContext -from botbuilder.core.adapters import TestAdapter -from botbuilder.schema import ( - Activity, - ActivityTypes, - ChannelAccount, - ConversationAccount, -) - - -class TestContext(TurnContext): - def __init__(self, request): - super().__init__(TestAdapter(), request) - self.sent: List[Activity] = list() - - self.on_send_activities(self.capture_sent_activities) - - async def capture_sent_activities( - self, context: TurnContext, activities, next - ): # pylint: disable=unused-argument - self.sent += activities - context.responded = True - - -class QnaApplicationTest(aiounittest.AsyncTestCase): - # Note this is NOT a real QnA Maker application ID nor a real QnA Maker subscription-key - # theses are GUIDs edited to look right to the parsing and validation code. - - _knowledge_base_id: str = "f028d9k3-7g9z-11d3-d300-2b8x98227q8w" - _endpoint_key: str = "1k997n7w-207z-36p3-j2u1-09tas20ci6011" - _host: str = "https://dummyqnahost.azurewebsites.net/qnamaker" - - tests_endpoint = QnAMakerEndpoint(_knowledge_base_id, _endpoint_key, _host) - - def test_qnamaker_construction(self): - # Arrange - endpoint = self.tests_endpoint - - # Act - qna = QnAMaker(endpoint) - endpoint = qna._endpoint - - # Assert - self.assertEqual( - "f028d9k3-7g9z-11d3-d300-2b8x98227q8w", endpoint.knowledge_base_id - ) - self.assertEqual("1k997n7w-207z-36p3-j2u1-09tas20ci6011", endpoint.endpoint_key) - self.assertEqual( - "https://dummyqnahost.azurewebsites.net/qnamaker", endpoint.host - ) - - def test_endpoint_with_empty_kbid(self): - empty_kbid = "" - - with self.assertRaises(TypeError): - QnAMakerEndpoint(empty_kbid, self._endpoint_key, self._host) - - def test_endpoint_with_empty_endpoint_key(self): - empty_endpoint_key = "" - - with self.assertRaises(TypeError): - QnAMakerEndpoint(self._knowledge_base_id, empty_endpoint_key, self._host) - - def test_endpoint_with_emptyhost(self): - with self.assertRaises(TypeError): - QnAMakerEndpoint(self._knowledge_base_id, self._endpoint_key, "") - - def test_qnamaker_with_none_endpoint(self): - with self.assertRaises(TypeError): - QnAMaker(None) - - def test_set_default_options_with_no_options_arg(self): - qna_without_options = QnAMaker(self.tests_endpoint) - options = qna_without_options._generate_answer_helper.options - - default_threshold = 0.3 - default_top = 1 - default_strict_filters = [] - - self.assertEqual(default_threshold, options.score_threshold) - self.assertEqual(default_top, options.top) - self.assertEqual(default_strict_filters, options.strict_filters) - - def test_options_passed_to_ctor(self): - options = QnAMakerOptions( - score_threshold=0.8, - timeout=9000, - top=5, - strict_filters=[Metadata("movie", "disney")], - ) - - qna_with_options = QnAMaker(self.tests_endpoint, options) - actual_options = qna_with_options._generate_answer_helper.options - - expected_threshold = 0.8 - expected_timeout = 9000 - expected_top = 5 - expected_strict_filters = [Metadata("movie", "disney")] - - self.assertEqual(expected_threshold, actual_options.score_threshold) - self.assertEqual(expected_timeout, actual_options.timeout) - self.assertEqual(expected_top, actual_options.top) - self.assertEqual( - expected_strict_filters[0].name, actual_options.strict_filters[0].name - ) - self.assertEqual( - expected_strict_filters[0].value, actual_options.strict_filters[0].value - ) - - async def test_returns_answer(self): - # Arrange - question: str = "how do I clean the stove?" - response_path: str = "ReturnsAnswer.json" - - # Act - result = await QnaApplicationTest._get_service_result(question, response_path) - - first_answer = result[0] - - # Assert - self.assertIsNotNone(result) - self.assertEqual(1, len(result)) - self.assertEqual( - "BaseCamp: You can use a damp rag to clean around the Power Pack", - first_answer.answer, - ) - - async def test_active_learning_enabled_status(self): - # Arrange - question: str = "how do I clean the stove?" - response_path: str = "ReturnsAnswer.json" - - # Act - result = await QnaApplicationTest._get_service_result_raw( - question, response_path - ) - - # Assert - self.assertIsNotNone(result) - self.assertEqual(1, len(result.answers)) - self.assertFalse(result.active_learning_enabled) - - async def test_returns_answer_using_options(self): - # Arrange - question: str = "up" - response_path: str = "AnswerWithOptions.json" - options = QnAMakerOptions( - score_threshold=0.8, top=5, strict_filters=[Metadata("movie", "disney")] - ) - - # Act - result = await QnaApplicationTest._get_service_result( - question, response_path, options=options - ) - - first_answer = result[0] - has_at_least_1_ans = True - first_metadata = first_answer.metadata[0] - - # Assert - self.assertIsNotNone(result) - self.assertEqual(has_at_least_1_ans, len(result) >= 1) - self.assertTrue(first_answer.answer[0]) - self.assertEqual("is a movie", first_answer.answer) - self.assertTrue(first_answer.score >= options.score_threshold) - self.assertEqual("movie", first_metadata.name) - self.assertEqual("disney", first_metadata.value) - - async def test_trace_test(self): - activity = Activity( - type=ActivityTypes.message, - text="how do I clean the stove?", - conversation=ConversationAccount(), - recipient=ChannelAccount(), - from_property=ChannelAccount(), - ) - - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - - context = TestContext(activity) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - result = await qna.get_answers(context) - - qna_trace_activities = list( - filter( - lambda act: act.type == "trace" and act.name == "QnAMaker", - context.sent, - ) - ) - trace_activity = qna_trace_activities[0] - - self.assertEqual("trace", trace_activity.type) - self.assertEqual("QnAMaker", trace_activity.name) - self.assertEqual("QnAMaker Trace", trace_activity.label) - self.assertEqual( - "https://www.qnamaker.ai/schemas/trace", trace_activity.value_type - ) - self.assertEqual(True, hasattr(trace_activity, "value")) - self.assertEqual(True, hasattr(trace_activity.value, "message")) - self.assertEqual(True, hasattr(trace_activity.value, "query_results")) - self.assertEqual(True, hasattr(trace_activity.value, "score_threshold")) - self.assertEqual(True, hasattr(trace_activity.value, "top")) - self.assertEqual(True, hasattr(trace_activity.value, "strict_filters")) - self.assertEqual( - self._knowledge_base_id, trace_activity.value.knowledge_base_id - ) - - return result - - async def test_returns_answer_with_timeout(self): - question: str = "how do I clean the stove?" - options = QnAMakerOptions(timeout=999999) - qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) - context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - result = await qna.get_answers(context, options) - - self.assertIsNotNone(result) - self.assertEqual( - options.timeout, qna._generate_answer_helper.options.timeout - ) - - async def test_telemetry_returns_answer(self): - # Arrange - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = True - context = QnaApplicationTest._get_context(question, TestAdapter()) - qna = QnAMaker( - QnaApplicationTest.tests_endpoint, - telemetry_client=telemetry_client, - log_personal_information=log_personal_information, - ) - - # Act - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context) - - telemetry_args = telemetry_client.track_event.call_args_list[0][1] - telemetry_properties = telemetry_args["properties"] - telemetry_metrics = telemetry_args["measurements"] - number_of_args = len(telemetry_args) - first_answer = telemetry_args["properties"][ - QnATelemetryConstants.answer_property - ] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - - # Assert - Check Telemetry logged. - self.assertEqual(1, telemetry_client.track_event.call_count) - self.assertEqual(3, number_of_args) - self.assertEqual("QnaMessage", telemetry_args["name"]) - self.assertTrue("answer" in telemetry_properties) - self.assertTrue("knowledgeBaseId" in telemetry_properties) - self.assertTrue("matchedQuestion" in telemetry_properties) - self.assertTrue("question" in telemetry_properties) - self.assertTrue("questionId" in telemetry_properties) - self.assertTrue("articleFound" in telemetry_properties) - self.assertEqual(expected_answer, first_answer) - self.assertTrue("score" in telemetry_metrics) - self.assertEqual(1, telemetry_metrics["score"]) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_telemetry_returns_answer_when_no_answer_found_in_kb(self): - # Arrange - question: str = "gibberish question" - response_json = QnaApplicationTest._get_json_for_file("NoAnswerFoundInKb.json") - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - qna = QnAMaker( - QnaApplicationTest.tests_endpoint, - telemetry_client=telemetry_client, - log_personal_information=True, - ) - context = QnaApplicationTest._get_context(question, TestAdapter()) - - # Act - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context) - - telemetry_args = telemetry_client.track_event.call_args_list[0][1] - telemetry_properties = telemetry_args["properties"] - number_of_args = len(telemetry_args) - first_answer = telemetry_args["properties"][ - QnATelemetryConstants.answer_property - ] - expected_answer = "No Qna Answer matched" - expected_matched_question = "No Qna Question matched" - - # Assert - Check Telemetry logged. - self.assertEqual(1, telemetry_client.track_event.call_count) - self.assertEqual(3, number_of_args) - self.assertEqual("QnaMessage", telemetry_args["name"]) - self.assertTrue("answer" in telemetry_properties) - self.assertTrue("knowledgeBaseId" in telemetry_properties) - self.assertTrue("matchedQuestion" in telemetry_properties) - self.assertEqual( - expected_matched_question, - telemetry_properties[QnATelemetryConstants.matched_question_property], - ) - self.assertTrue("question" in telemetry_properties) - self.assertTrue("questionId" in telemetry_properties) - self.assertTrue("articleFound" in telemetry_properties) - self.assertEqual(expected_answer, first_answer) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(0, len(results)) - - async def test_telemetry_pii(self): - # Arrange - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = False - context = QnaApplicationTest._get_context(question, TestAdapter()) - qna = QnAMaker( - QnaApplicationTest.tests_endpoint, - telemetry_client=telemetry_client, - log_personal_information=log_personal_information, - ) - - # Act - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context) - - telemetry_args = telemetry_client.track_event.call_args_list[0][1] - telemetry_properties = telemetry_args["properties"] - telemetry_metrics = telemetry_args["measurements"] - number_of_args = len(telemetry_args) - first_answer = telemetry_args["properties"][ - QnATelemetryConstants.answer_property - ] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - - # Assert - Validate PII properties not logged. - self.assertEqual(1, telemetry_client.track_event.call_count) - self.assertEqual(3, number_of_args) - self.assertEqual("QnaMessage", telemetry_args["name"]) - self.assertTrue("answer" in telemetry_properties) - self.assertTrue("knowledgeBaseId" in telemetry_properties) - self.assertTrue("matchedQuestion" in telemetry_properties) - self.assertTrue("question" not in telemetry_properties) - self.assertTrue("questionId" in telemetry_properties) - self.assertTrue("articleFound" in telemetry_properties) - self.assertEqual(expected_answer, first_answer) - self.assertTrue("score" in telemetry_metrics) - self.assertEqual(1, telemetry_metrics["score"]) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_telemetry_override(self): - # Arrange - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - context = QnaApplicationTest._get_context(question, TestAdapter()) - options = QnAMakerOptions(top=1) - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = False - - # Act - Override the QnAMaker object to log custom stuff and honor params passed in. - telemetry_properties: Dict[str, str] = {"id": "MyId"} - qna = QnaApplicationTest.OverrideTelemetry( - QnaApplicationTest.tests_endpoint, - options, - None, - telemetry_client, - log_personal_information, - ) - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context, options, telemetry_properties) - - telemetry_args = telemetry_client.track_event.call_args_list - first_call_args = telemetry_args[0][0] - first_call_properties = first_call_args[1] - second_call_args = telemetry_args[1][0] - second_call_properties = second_call_args[1] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - - # Assert - self.assertEqual(2, telemetry_client.track_event.call_count) - self.assertEqual(2, len(first_call_args)) - self.assertEqual("QnaMessage", first_call_args[0]) - self.assertEqual(2, len(first_call_properties)) - self.assertTrue("my_important_property" in first_call_properties) - self.assertEqual( - "my_important_value", first_call_properties["my_important_property"] - ) - self.assertTrue("id" in first_call_properties) - self.assertEqual("MyId", first_call_properties["id"]) - - self.assertEqual("my_second_event", second_call_args[0]) - self.assertTrue("my_important_property2" in second_call_properties) - self.assertEqual( - "my_important_value2", second_call_properties["my_important_property2"] - ) - - # Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_telemetry_additional_props_metrics(self): - # Arrange - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - context = QnaApplicationTest._get_context(question, TestAdapter()) - options = QnAMakerOptions(top=1) - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = False - - # Act - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - qna = QnAMaker( - QnaApplicationTest.tests_endpoint, - options, - None, - telemetry_client, - log_personal_information, - ) - telemetry_properties: Dict[str, str] = { - "my_important_property": "my_important_value" - } - telemetry_metrics: Dict[str, float] = {"my_important_metric": 3.14159} - - results = await qna.get_answers( - context, None, telemetry_properties, telemetry_metrics - ) - - # Assert - Added properties were added. - telemetry_args = telemetry_client.track_event.call_args_list[0][1] - telemetry_properties = telemetry_args["properties"] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - - self.assertEqual(1, telemetry_client.track_event.call_count) - self.assertEqual(3, len(telemetry_args)) - self.assertEqual("QnaMessage", telemetry_args["name"]) - self.assertTrue("knowledgeBaseId" in telemetry_properties) - self.assertTrue("question" not in telemetry_properties) - self.assertTrue("matchedQuestion" in telemetry_properties) - self.assertTrue("questionId" in telemetry_properties) - self.assertTrue("answer" in telemetry_properties) - self.assertTrue(expected_answer, telemetry_properties["answer"]) - self.assertTrue("my_important_property" in telemetry_properties) - self.assertEqual( - "my_important_value", telemetry_properties["my_important_property"] - ) - - tracked_metrics = telemetry_args["measurements"] - - self.assertEqual(2, len(tracked_metrics)) - self.assertTrue("score" in tracked_metrics) - self.assertTrue("my_important_metric" in tracked_metrics) - self.assertEqual(3.14159, tracked_metrics["my_important_metric"]) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_telemetry_additional_props_override(self): - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - context = QnaApplicationTest._get_context(question, TestAdapter()) - options = QnAMakerOptions(top=1) - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = False - - # Act - Pass in properties during QnA invocation that override default properties - # NOTE: We are invoking this with PII turned OFF, and passing a PII property (originalQuestion). - qna = QnAMaker( - QnaApplicationTest.tests_endpoint, - options, - None, - telemetry_client, - log_personal_information, - ) - telemetry_properties = { - "knowledge_base_id": "my_important_value", - "original_question": "my_important_value2", - } - telemetry_metrics = {"score": 3.14159} - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers( - context, None, telemetry_properties, telemetry_metrics - ) - - # Assert - Added properties were added. - tracked_args = telemetry_client.track_event.call_args_list[0][1] - tracked_properties = tracked_args["properties"] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - tracked_metrics = tracked_args["measurements"] - - self.assertEqual(1, telemetry_client.track_event.call_count) - self.assertEqual(3, len(tracked_args)) - self.assertEqual("QnaMessage", tracked_args["name"]) - self.assertTrue("knowledge_base_id" in tracked_properties) - self.assertEqual( - "my_important_value", tracked_properties["knowledge_base_id"] - ) - self.assertTrue("original_question" in tracked_properties) - self.assertTrue("matchedQuestion" in tracked_properties) - self.assertEqual( - "my_important_value2", tracked_properties["original_question"] - ) - self.assertTrue("question" not in tracked_properties) - self.assertTrue("questionId" in tracked_properties) - self.assertTrue("answer" in tracked_properties) - self.assertEqual(expected_answer, tracked_properties["answer"]) - self.assertTrue("my_important_property" not in tracked_properties) - self.assertEqual(1, len(tracked_metrics)) - self.assertTrue("score" in tracked_metrics) - self.assertEqual(3.14159, tracked_metrics["score"]) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_telemetry_fill_props_override(self): - # Arrange - question: str = "how do I clean the stove?" - response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") - context: TurnContext = QnaApplicationTest._get_context(question, TestAdapter()) - options = QnAMakerOptions(top=1) - telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) - log_personal_information = False - - # Act - Pass in properties during QnA invocation that override default properties - # In addition Override with derivation. This presents an interesting question of order of setting - # properties. - # If I want to override "originalQuestion" property: - # - Set in "Stock" schema - # - Set in derived QnAMaker class - # - Set in GetAnswersAsync - # Logically, the GetAnswersAync should win. But ultimately OnQnaResultsAsync decides since it is the last - # code to touch the properties before logging (since it actually logs the event). - qna = QnaApplicationTest.OverrideFillTelemetry( - QnaApplicationTest.tests_endpoint, - options, - None, - telemetry_client, - log_personal_information, - ) - telemetry_properties: Dict[str, str] = { - "knowledgeBaseId": "my_important_value", - "matchedQuestion": "my_important_value2", - } - telemetry_metrics: Dict[str, float] = {"score": 3.14159} - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers( - context, None, telemetry_properties, telemetry_metrics - ) - - # Assert - Added properties were added. - first_call_args = telemetry_client.track_event.call_args_list[0][0] - first_properties = first_call_args[1] - expected_answer = ( - "BaseCamp: You can use a damp rag to clean around the Power Pack" - ) - first_metrics = first_call_args[2] - - self.assertEqual(2, telemetry_client.track_event.call_count) - self.assertEqual(3, len(first_call_args)) - self.assertEqual("QnaMessage", first_call_args[0]) - self.assertEqual(6, len(first_properties)) - self.assertTrue("knowledgeBaseId" in first_properties) - self.assertEqual("my_important_value", first_properties["knowledgeBaseId"]) - self.assertTrue("matchedQuestion" in first_properties) - self.assertEqual("my_important_value2", first_properties["matchedQuestion"]) - self.assertTrue("questionId" in first_properties) - self.assertTrue("answer" in first_properties) - self.assertEqual(expected_answer, first_properties["answer"]) - self.assertTrue("articleFound" in first_properties) - self.assertTrue("my_important_property" in first_properties) - self.assertEqual( - "my_important_value", first_properties["my_important_property"] - ) - - self.assertEqual(1, len(first_metrics)) - self.assertTrue("score" in first_metrics) - self.assertEqual(3.14159, first_metrics["score"]) - - # Assert - Validate we didn't break QnA functionality. - self.assertIsNotNone(results) - self.assertEqual(1, len(results)) - self.assertEqual(expected_answer, results[0].answer) - self.assertEqual("Editorial", results[0].source) - - async def test_call_train(self): - feedback_records = [] - - feedback1 = FeedbackRecord( - qna_id=1, user_id="test", user_question="How are you?" - ) - - feedback2 = FeedbackRecord(qna_id=2, user_id="test", user_question="What up??") - - feedback_records.extend([feedback1, feedback2]) - - with patch.object( - QnAMaker, "call_train", return_value=None - ) as mocked_call_train: - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - qna.call_train(feedback_records) - - mocked_call_train.assert_called_once_with(feedback_records) - - async def test_should_filter_low_score_variation(self): - options = QnAMakerOptions(top=5) - qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) - question: str = "Q11" - context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file("TopNAnswer.json") - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context) - self.assertEqual(4, len(results), "Should have received 4 answers.") - - filtered_results = qna.get_low_score_variation(results) - self.assertEqual( - 3, - len(filtered_results), - "Should have 3 filtered answers after low score variation.", - ) - - async def test_should_answer_with_is_test_true(self): - options = QnAMakerOptions(top=1, is_test=True) - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - question: str = "Q11" - context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file( - "QnaMaker_IsTest_true.json" - ) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context, options=options) - self.assertEqual(0, len(results), "Should have received zero answer.") - - async def test_should_answer_with_ranker_type_question_only(self): - options = QnAMakerOptions(top=1, ranker_type="QuestionOnly") - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - question: str = "Q11" - context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file( - "QnaMaker_RankerType_QuestionOnly.json" - ) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(context, options=options) - self.assertEqual(2, len(results), "Should have received two answers.") - - async def test_should_answer_with_prompts(self): - options = QnAMakerOptions(top=2) - qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) - question: str = "how do I clean the stove?" - turn_context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file("AnswerWithPrompts.json") - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(turn_context, options) - self.assertEqual(1, len(results), "Should have received 1 answers.") - self.assertEqual( - 1, len(results[0].context.prompts), "Should have received 1 prompt." - ) - - async def test_should_answer_with_high_score_provided_context(self): - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - question: str = "where can I buy?" - context = QnARequestContext( - previous_qna_id=5, prvious_user_query="how do I clean the stove?" - ) - options = QnAMakerOptions(top=2, qna_id=55, context=context) - turn_context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file( - "AnswerWithHighScoreProvidedContext.json" - ) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(turn_context, options) - self.assertEqual(1, len(results), "Should have received 1 answers.") - self.assertEqual(1, results[0].score, "Score should be high.") - - async def test_should_answer_with_high_score_provided_qna_id(self): - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - question: str = "where can I buy?" - - options = QnAMakerOptions(top=2, qna_id=55) - turn_context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file( - "AnswerWithHighScoreProvidedContext.json" - ) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(turn_context, options) - self.assertEqual(1, len(results), "Should have received 1 answers.") - self.assertEqual(1, results[0].score, "Score should be high.") - - async def test_should_answer_with_low_score_without_provided_context(self): - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - question: str = "where can I buy?" - options = QnAMakerOptions(top=2, context=None) - - turn_context = QnaApplicationTest._get_context(question, TestAdapter()) - response_json = QnaApplicationTest._get_json_for_file( - "AnswerWithLowScoreProvidedWithoutContext.json" - ) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - results = await qna.get_answers(turn_context, options) - self.assertEqual( - 2, len(results), "Should have received more than one answers." - ) - self.assertEqual(True, results[0].score < 1, "Score should be low.") - - @classmethod - async def _get_service_result( - cls, - utterance: str, - response_file: str, - bot_adapter: BotAdapter = TestAdapter(), - options: QnAMakerOptions = None, - ) -> [dict]: - response_json = QnaApplicationTest._get_json_for_file(response_file) - - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - context = QnaApplicationTest._get_context(utterance, bot_adapter) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - result = await qna.get_answers(context, options) - - return result - - @classmethod - async def _get_service_result_raw( - cls, - utterance: str, - response_file: str, - bot_adapter: BotAdapter = TestAdapter(), - options: QnAMakerOptions = None, - ) -> [dict]: - response_json = QnaApplicationTest._get_json_for_file(response_file) - - qna = QnAMaker(QnaApplicationTest.tests_endpoint) - context = QnaApplicationTest._get_context(utterance, bot_adapter) - - with patch( - "aiohttp.ClientSession.post", - return_value=aiounittest.futurized(response_json), - ): - result = await qna.get_answers_raw(context, options) - - return result - - @classmethod - def _get_json_for_file(cls, response_file: str) -> object: - curr_dir = path.dirname(path.abspath(__file__)) - response_path = path.join(curr_dir, "test_data", response_file) - - with open(response_path, "r", encoding="utf-8-sig") as file: - response_str = file.read() - response_json = json.loads(response_str) - - return response_json - - @staticmethod - def _get_context(question: str, bot_adapter: BotAdapter) -> TurnContext: - test_adapter = bot_adapter or TestAdapter() - activity = Activity( - type=ActivityTypes.message, - text=question, - conversation=ConversationAccount(), - recipient=ChannelAccount(), - from_property=ChannelAccount(), - ) - - return TurnContext(test_adapter, activity) - - class OverrideTelemetry(QnAMaker): - def __init__( # pylint: disable=useless-super-delegation - self, - endpoint: QnAMakerEndpoint, - options: QnAMakerOptions, - http_client: ClientSession, - telemetry_client: BotTelemetryClient, - log_personal_information: bool, - ): - super().__init__( - endpoint, - options, - http_client, - telemetry_client, - log_personal_information, - ) - - async def on_qna_result( # pylint: disable=unused-argument - self, - query_results: [QueryResult], - turn_context: TurnContext, - telemetry_properties: Dict[str, str] = None, - telemetry_metrics: Dict[str, float] = None, - ): - properties = telemetry_properties or {} - - # get_answers overrides derived class - properties["my_important_property"] = "my_important_value" - - # Log event - self.telemetry_client.track_event( - QnATelemetryConstants.qna_message_event, properties - ) - - # Create 2nd event. - second_event_properties = {"my_important_property2": "my_important_value2"} - self.telemetry_client.track_event( - "my_second_event", second_event_properties - ) - - class OverrideFillTelemetry(QnAMaker): - def __init__( # pylint: disable=useless-super-delegation - self, - endpoint: QnAMakerEndpoint, - options: QnAMakerOptions, - http_client: ClientSession, - telemetry_client: BotTelemetryClient, - log_personal_information: bool, - ): - super().__init__( - endpoint, - options, - http_client, - telemetry_client, - log_personal_information, - ) - - async def on_qna_result( - self, - query_results: [QueryResult], - turn_context: TurnContext, - telemetry_properties: Dict[str, str] = None, - telemetry_metrics: Dict[str, float] = None, - ): - event_data = await self.fill_qna_event( - query_results, turn_context, telemetry_properties, telemetry_metrics - ) - - # Add my property. - event_data.properties.update( - {"my_important_property": "my_important_value"} - ) - - # Log QnaMessage event. - self.telemetry_client.track_event( - QnATelemetryConstants.qna_message_event, - event_data.properties, - event_data.metrics, - ) - - # Create second event. - second_event_properties: Dict[str, str] = { - "my_important_property2": "my_important_value2" - } - - self.telemetry_client.track_event("MySecondEvent", second_event_properties) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# pylint: disable=protected-access + +import json +from os import path +from typing import List, Dict +import unittest +from unittest.mock import patch +from aiohttp import ClientSession + +import aiounittest +from botbuilder.ai.qna import QnAMakerEndpoint, QnAMaker, QnAMakerOptions +from botbuilder.ai.qna.models import ( + FeedbackRecord, + Metadata, + QueryResult, + QnARequestContext, +) +from botbuilder.ai.qna.utils import QnATelemetryConstants +from botbuilder.core import BotAdapter, BotTelemetryClient, TurnContext +from botbuilder.core.adapters import TestAdapter +from botbuilder.schema import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationAccount, +) + + +class TestContext(TurnContext): + __test__ = False + + def __init__(self, request): + super().__init__(TestAdapter(), request) + self.sent: List[Activity] = list() + + self.on_send_activities(self.capture_sent_activities) + + async def capture_sent_activities( + self, context: TurnContext, activities, next + ): # pylint: disable=unused-argument + self.sent += activities + context.responded = True + + +class QnaApplicationTest(aiounittest.AsyncTestCase): + # Note this is NOT a real QnA Maker application ID nor a real QnA Maker subscription-key + # theses are GUIDs edited to look right to the parsing and validation code. + + _knowledge_base_id: str = "f028d9k3-7g9z-11d3-d300-2b8x98227q8w" + _endpoint_key: str = "1k997n7w-207z-36p3-j2u1-09tas20ci6011" + _host: str = "https://dummyqnahost.azurewebsites.net/qnamaker" + + tests_endpoint = QnAMakerEndpoint(_knowledge_base_id, _endpoint_key, _host) + + def test_qnamaker_construction(self): + # Arrange + endpoint = self.tests_endpoint + + # Act + qna = QnAMaker(endpoint) + endpoint = qna._endpoint + + # Assert + self.assertEqual( + "f028d9k3-7g9z-11d3-d300-2b8x98227q8w", endpoint.knowledge_base_id + ) + self.assertEqual("1k997n7w-207z-36p3-j2u1-09tas20ci6011", endpoint.endpoint_key) + self.assertEqual( + "https://dummyqnahost.azurewebsites.net/qnamaker", endpoint.host + ) + + def test_endpoint_with_empty_kbid(self): + empty_kbid = "" + + with self.assertRaises(TypeError): + QnAMakerEndpoint(empty_kbid, self._endpoint_key, self._host) + + def test_endpoint_with_empty_endpoint_key(self): + empty_endpoint_key = "" + + with self.assertRaises(TypeError): + QnAMakerEndpoint(self._knowledge_base_id, empty_endpoint_key, self._host) + + def test_endpoint_with_emptyhost(self): + with self.assertRaises(TypeError): + QnAMakerEndpoint(self._knowledge_base_id, self._endpoint_key, "") + + def test_qnamaker_with_none_endpoint(self): + with self.assertRaises(TypeError): + QnAMaker(None) + + def test_set_default_options_with_no_options_arg(self): + qna_without_options = QnAMaker(self.tests_endpoint) + options = qna_without_options._generate_answer_helper.options + + default_threshold = 0.3 + default_top = 1 + default_strict_filters = [] + + self.assertEqual(default_threshold, options.score_threshold) + self.assertEqual(default_top, options.top) + self.assertEqual(default_strict_filters, options.strict_filters) + + def test_options_passed_to_ctor(self): + options = QnAMakerOptions( + score_threshold=0.8, + timeout=9000, + top=5, + strict_filters=[Metadata("movie", "disney")], + ) + + qna_with_options = QnAMaker(self.tests_endpoint, options) + actual_options = qna_with_options._generate_answer_helper.options + + expected_threshold = 0.8 + expected_timeout = 9000 + expected_top = 5 + expected_strict_filters = [Metadata("movie", "disney")] + + self.assertEqual(expected_threshold, actual_options.score_threshold) + self.assertEqual(expected_timeout, actual_options.timeout) + self.assertEqual(expected_top, actual_options.top) + self.assertEqual( + expected_strict_filters[0].name, actual_options.strict_filters[0].name + ) + self.assertEqual( + expected_strict_filters[0].value, actual_options.strict_filters[0].value + ) + + async def test_returns_answer(self): + # Arrange + question: str = "how do I clean the stove?" + response_path: str = "ReturnsAnswer.json" + + # Act + result = await QnaApplicationTest._get_service_result(question, response_path) + + first_answer = result[0] + + # Assert + self.assertIsNotNone(result) + self.assertEqual(1, len(result)) + self.assertEqual( + "BaseCamp: You can use a damp rag to clean around the Power Pack", + first_answer.answer, + ) + + async def test_active_learning_enabled_status(self): + # Arrange + question: str = "how do I clean the stove?" + response_path: str = "ReturnsAnswer.json" + + # Act + result = await QnaApplicationTest._get_service_result_raw( + question, response_path + ) + + # Assert + self.assertIsNotNone(result) + self.assertEqual(1, len(result.answers)) + self.assertFalse(result.active_learning_enabled) + + async def test_returns_answer_using_options(self): + # Arrange + question: str = "up" + response_path: str = "AnswerWithOptions.json" + options = QnAMakerOptions( + score_threshold=0.8, top=5, strict_filters=[Metadata("movie", "disney")] + ) + + # Act + result = await QnaApplicationTest._get_service_result( + question, response_path, options=options + ) + + first_answer = result[0] + has_at_least_1_ans = True + first_metadata = first_answer.metadata[0] + + # Assert + self.assertIsNotNone(result) + self.assertEqual(has_at_least_1_ans, len(result) >= 1) + self.assertTrue(first_answer.answer[0]) + self.assertEqual("is a movie", first_answer.answer) + self.assertTrue(first_answer.score >= options.score_threshold) + self.assertEqual("movie", first_metadata.name) + self.assertEqual("disney", first_metadata.value) + + async def test_trace_test(self): + activity = Activity( + type=ActivityTypes.message, + text="how do I clean the stove?", + conversation=ConversationAccount(), + recipient=ChannelAccount(), + from_property=ChannelAccount(), + ) + + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + + context = TestContext(activity) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + result = await qna.get_answers(context) + + qna_trace_activities = list( + filter( + lambda act: act.type == "trace" and act.name == "QnAMaker", + context.sent, + ) + ) + trace_activity = qna_trace_activities[0] + + self.assertEqual("trace", trace_activity.type) + self.assertEqual("QnAMaker", trace_activity.name) + self.assertEqual("QnAMaker Trace", trace_activity.label) + self.assertEqual( + "https://www.qnamaker.ai/schemas/trace", trace_activity.value_type + ) + self.assertEqual(True, hasattr(trace_activity, "value")) + self.assertEqual(True, hasattr(trace_activity.value, "message")) + self.assertEqual(True, hasattr(trace_activity.value, "query_results")) + self.assertEqual(True, hasattr(trace_activity.value, "score_threshold")) + self.assertEqual(True, hasattr(trace_activity.value, "top")) + self.assertEqual(True, hasattr(trace_activity.value, "strict_filters")) + self.assertEqual( + self._knowledge_base_id, trace_activity.value.knowledge_base_id + ) + + return result + + async def test_returns_answer_with_timeout(self): + question: str = "how do I clean the stove?" + options = QnAMakerOptions(timeout=999999) + qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) + context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + result = await qna.get_answers(context, options) + + self.assertIsNotNone(result) + self.assertEqual( + options.timeout, qna._generate_answer_helper.options.timeout + ) + + async def test_telemetry_returns_answer(self): + # Arrange + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = True + context = QnaApplicationTest._get_context(question, TestAdapter()) + qna = QnAMaker( + QnaApplicationTest.tests_endpoint, + telemetry_client=telemetry_client, + log_personal_information=log_personal_information, + ) + + # Act + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context) + + telemetry_args = telemetry_client.track_event.call_args_list[0][1] + telemetry_properties = telemetry_args["properties"] + telemetry_metrics = telemetry_args["measurements"] + number_of_args = len(telemetry_args) + first_answer = telemetry_args["properties"][ + QnATelemetryConstants.answer_property + ] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + + # Assert - Check Telemetry logged. + self.assertEqual(1, telemetry_client.track_event.call_count) + self.assertEqual(3, number_of_args) + self.assertEqual("QnaMessage", telemetry_args["name"]) + self.assertTrue("answer" in telemetry_properties) + self.assertTrue("knowledgeBaseId" in telemetry_properties) + self.assertTrue("matchedQuestion" in telemetry_properties) + self.assertTrue("question" in telemetry_properties) + self.assertTrue("questionId" in telemetry_properties) + self.assertTrue("articleFound" in telemetry_properties) + self.assertEqual(expected_answer, first_answer) + self.assertTrue("score" in telemetry_metrics) + self.assertEqual(1, telemetry_metrics["score"]) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_telemetry_returns_answer_when_no_answer_found_in_kb(self): + # Arrange + question: str = "gibberish question" + response_json = QnaApplicationTest._get_json_for_file("NoAnswerFoundInKb.json") + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + qna = QnAMaker( + QnaApplicationTest.tests_endpoint, + telemetry_client=telemetry_client, + log_personal_information=True, + ) + context = QnaApplicationTest._get_context(question, TestAdapter()) + + # Act + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context) + + telemetry_args = telemetry_client.track_event.call_args_list[0][1] + telemetry_properties = telemetry_args["properties"] + number_of_args = len(telemetry_args) + first_answer = telemetry_args["properties"][ + QnATelemetryConstants.answer_property + ] + expected_answer = "No Qna Answer matched" + expected_matched_question = "No Qna Question matched" + + # Assert - Check Telemetry logged. + self.assertEqual(1, telemetry_client.track_event.call_count) + self.assertEqual(3, number_of_args) + self.assertEqual("QnaMessage", telemetry_args["name"]) + self.assertTrue("answer" in telemetry_properties) + self.assertTrue("knowledgeBaseId" in telemetry_properties) + self.assertTrue("matchedQuestion" in telemetry_properties) + self.assertEqual( + expected_matched_question, + telemetry_properties[QnATelemetryConstants.matched_question_property], + ) + self.assertTrue("question" in telemetry_properties) + self.assertTrue("questionId" in telemetry_properties) + self.assertTrue("articleFound" in telemetry_properties) + self.assertEqual(expected_answer, first_answer) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(0, len(results)) + + async def test_telemetry_pii(self): + # Arrange + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = False + context = QnaApplicationTest._get_context(question, TestAdapter()) + qna = QnAMaker( + QnaApplicationTest.tests_endpoint, + telemetry_client=telemetry_client, + log_personal_information=log_personal_information, + ) + + # Act + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context) + + telemetry_args = telemetry_client.track_event.call_args_list[0][1] + telemetry_properties = telemetry_args["properties"] + telemetry_metrics = telemetry_args["measurements"] + number_of_args = len(telemetry_args) + first_answer = telemetry_args["properties"][ + QnATelemetryConstants.answer_property + ] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + + # Assert - Validate PII properties not logged. + self.assertEqual(1, telemetry_client.track_event.call_count) + self.assertEqual(3, number_of_args) + self.assertEqual("QnaMessage", telemetry_args["name"]) + self.assertTrue("answer" in telemetry_properties) + self.assertTrue("knowledgeBaseId" in telemetry_properties) + self.assertTrue("matchedQuestion" in telemetry_properties) + self.assertTrue("question" not in telemetry_properties) + self.assertTrue("questionId" in telemetry_properties) + self.assertTrue("articleFound" in telemetry_properties) + self.assertEqual(expected_answer, first_answer) + self.assertTrue("score" in telemetry_metrics) + self.assertEqual(1, telemetry_metrics["score"]) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_telemetry_override(self): + # Arrange + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + context = QnaApplicationTest._get_context(question, TestAdapter()) + options = QnAMakerOptions(top=1) + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = False + + # Act - Override the QnAMaker object to log custom stuff and honor params passed in. + telemetry_properties: Dict[str, str] = {"id": "MyId"} + qna = QnaApplicationTest.OverrideTelemetry( + QnaApplicationTest.tests_endpoint, + options, + None, + telemetry_client, + log_personal_information, + ) + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context, options, telemetry_properties) + + telemetry_args = telemetry_client.track_event.call_args_list + first_call_args = telemetry_args[0][0] + first_call_properties = first_call_args[1] + second_call_args = telemetry_args[1][0] + second_call_properties = second_call_args[1] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + + # Assert + self.assertEqual(2, telemetry_client.track_event.call_count) + self.assertEqual(2, len(first_call_args)) + self.assertEqual("QnaMessage", first_call_args[0]) + self.assertEqual(2, len(first_call_properties)) + self.assertTrue("my_important_property" in first_call_properties) + self.assertEqual( + "my_important_value", first_call_properties["my_important_property"] + ) + self.assertTrue("id" in first_call_properties) + self.assertEqual("MyId", first_call_properties["id"]) + + self.assertEqual("my_second_event", second_call_args[0]) + self.assertTrue("my_important_property2" in second_call_properties) + self.assertEqual( + "my_important_value2", second_call_properties["my_important_property2"] + ) + + # Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_telemetry_additional_props_metrics(self): + # Arrange + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + context = QnaApplicationTest._get_context(question, TestAdapter()) + options = QnAMakerOptions(top=1) + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = False + + # Act + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + qna = QnAMaker( + QnaApplicationTest.tests_endpoint, + options, + None, + telemetry_client, + log_personal_information, + ) + telemetry_properties: Dict[str, str] = { + "my_important_property": "my_important_value" + } + telemetry_metrics: Dict[str, float] = {"my_important_metric": 3.14159} + + results = await qna.get_answers( + context, None, telemetry_properties, telemetry_metrics + ) + + # Assert - Added properties were added. + telemetry_args = telemetry_client.track_event.call_args_list[0][1] + telemetry_properties = telemetry_args["properties"] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + + self.assertEqual(1, telemetry_client.track_event.call_count) + self.assertEqual(3, len(telemetry_args)) + self.assertEqual("QnaMessage", telemetry_args["name"]) + self.assertTrue("knowledgeBaseId" in telemetry_properties) + self.assertTrue("question" not in telemetry_properties) + self.assertTrue("matchedQuestion" in telemetry_properties) + self.assertTrue("questionId" in telemetry_properties) + self.assertTrue("answer" in telemetry_properties) + self.assertTrue(expected_answer, telemetry_properties["answer"]) + self.assertTrue("my_important_property" in telemetry_properties) + self.assertEqual( + "my_important_value", telemetry_properties["my_important_property"] + ) + + tracked_metrics = telemetry_args["measurements"] + + self.assertEqual(2, len(tracked_metrics)) + self.assertTrue("score" in tracked_metrics) + self.assertTrue("my_important_metric" in tracked_metrics) + self.assertEqual(3.14159, tracked_metrics["my_important_metric"]) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_telemetry_additional_props_override(self): + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + context = QnaApplicationTest._get_context(question, TestAdapter()) + options = QnAMakerOptions(top=1) + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = False + + # Act - Pass in properties during QnA invocation that override default properties + # NOTE: We are invoking this with PII turned OFF, and passing a PII property (originalQuestion). + qna = QnAMaker( + QnaApplicationTest.tests_endpoint, + options, + None, + telemetry_client, + log_personal_information, + ) + telemetry_properties = { + "knowledge_base_id": "my_important_value", + "original_question": "my_important_value2", + } + telemetry_metrics = {"score": 3.14159} + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers( + context, None, telemetry_properties, telemetry_metrics + ) + + # Assert - Added properties were added. + tracked_args = telemetry_client.track_event.call_args_list[0][1] + tracked_properties = tracked_args["properties"] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + tracked_metrics = tracked_args["measurements"] + + self.assertEqual(1, telemetry_client.track_event.call_count) + self.assertEqual(3, len(tracked_args)) + self.assertEqual("QnaMessage", tracked_args["name"]) + self.assertTrue("knowledge_base_id" in tracked_properties) + self.assertEqual( + "my_important_value", tracked_properties["knowledge_base_id"] + ) + self.assertTrue("original_question" in tracked_properties) + self.assertTrue("matchedQuestion" in tracked_properties) + self.assertEqual( + "my_important_value2", tracked_properties["original_question"] + ) + self.assertTrue("question" not in tracked_properties) + self.assertTrue("questionId" in tracked_properties) + self.assertTrue("answer" in tracked_properties) + self.assertEqual(expected_answer, tracked_properties["answer"]) + self.assertTrue("my_important_property" not in tracked_properties) + self.assertEqual(1, len(tracked_metrics)) + self.assertTrue("score" in tracked_metrics) + self.assertEqual(3.14159, tracked_metrics["score"]) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_telemetry_fill_props_override(self): + # Arrange + question: str = "how do I clean the stove?" + response_json = QnaApplicationTest._get_json_for_file("ReturnsAnswer.json") + context: TurnContext = QnaApplicationTest._get_context(question, TestAdapter()) + options = QnAMakerOptions(top=1) + telemetry_client = unittest.mock.create_autospec(BotTelemetryClient) + log_personal_information = False + + # Act - Pass in properties during QnA invocation that override default properties + # In addition Override with derivation. This presents an interesting question of order of setting + # properties. + # If I want to override "originalQuestion" property: + # - Set in "Stock" schema + # - Set in derived QnAMaker class + # - Set in GetAnswersAsync + # Logically, the GetAnswersAync should win. But ultimately OnQnaResultsAsync decides since it is the last + # code to touch the properties before logging (since it actually logs the event). + qna = QnaApplicationTest.OverrideFillTelemetry( + QnaApplicationTest.tests_endpoint, + options, + None, + telemetry_client, + log_personal_information, + ) + telemetry_properties: Dict[str, str] = { + "knowledgeBaseId": "my_important_value", + "matchedQuestion": "my_important_value2", + } + telemetry_metrics: Dict[str, float] = {"score": 3.14159} + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers( + context, None, telemetry_properties, telemetry_metrics + ) + + # Assert - Added properties were added. + first_call_args = telemetry_client.track_event.call_args_list[0][0] + first_properties = first_call_args[1] + expected_answer = ( + "BaseCamp: You can use a damp rag to clean around the Power Pack" + ) + first_metrics = first_call_args[2] + + self.assertEqual(2, telemetry_client.track_event.call_count) + self.assertEqual(3, len(first_call_args)) + self.assertEqual("QnaMessage", first_call_args[0]) + self.assertEqual(6, len(first_properties)) + self.assertTrue("knowledgeBaseId" in first_properties) + self.assertEqual("my_important_value", first_properties["knowledgeBaseId"]) + self.assertTrue("matchedQuestion" in first_properties) + self.assertEqual("my_important_value2", first_properties["matchedQuestion"]) + self.assertTrue("questionId" in first_properties) + self.assertTrue("answer" in first_properties) + self.assertEqual(expected_answer, first_properties["answer"]) + self.assertTrue("articleFound" in first_properties) + self.assertTrue("my_important_property" in first_properties) + self.assertEqual( + "my_important_value", first_properties["my_important_property"] + ) + + self.assertEqual(1, len(first_metrics)) + self.assertTrue("score" in first_metrics) + self.assertEqual(3.14159, first_metrics["score"]) + + # Assert - Validate we didn't break QnA functionality. + self.assertIsNotNone(results) + self.assertEqual(1, len(results)) + self.assertEqual(expected_answer, results[0].answer) + self.assertEqual("Editorial", results[0].source) + + async def test_call_train(self): + feedback_records = [] + + feedback1 = FeedbackRecord( + qna_id=1, user_id="test", user_question="How are you?" + ) + + feedback2 = FeedbackRecord(qna_id=2, user_id="test", user_question="What up??") + + feedback_records.extend([feedback1, feedback2]) + + with patch.object( + QnAMaker, "call_train", return_value=None + ) as mocked_call_train: + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + qna.call_train(feedback_records) + + mocked_call_train.assert_called_once_with(feedback_records) + + async def test_should_filter_low_score_variation(self): + options = QnAMakerOptions(top=5) + qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) + question: str = "Q11" + context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file("TopNAnswer.json") + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context) + self.assertEqual(4, len(results), "Should have received 4 answers.") + + filtered_results = qna.get_low_score_variation(results) + self.assertEqual( + 3, + len(filtered_results), + "Should have 3 filtered answers after low score variation.", + ) + + async def test_should_answer_with_is_test_true(self): + options = QnAMakerOptions(top=1, is_test=True) + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + question: str = "Q11" + context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file( + "QnaMaker_IsTest_true.json" + ) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context, options=options) + self.assertEqual(0, len(results), "Should have received zero answer.") + + async def test_should_answer_with_ranker_type_question_only(self): + options = QnAMakerOptions(top=1, ranker_type="QuestionOnly") + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + question: str = "Q11" + context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file( + "QnaMaker_RankerType_QuestionOnly.json" + ) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(context, options=options) + self.assertEqual(2, len(results), "Should have received two answers.") + + async def test_should_answer_with_prompts(self): + options = QnAMakerOptions(top=2) + qna = QnAMaker(QnaApplicationTest.tests_endpoint, options) + question: str = "how do I clean the stove?" + turn_context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file("AnswerWithPrompts.json") + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(turn_context, options) + self.assertEqual(1, len(results), "Should have received 1 answers.") + self.assertEqual( + 1, len(results[0].context.prompts), "Should have received 1 prompt." + ) + + async def test_should_answer_with_high_score_provided_context(self): + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + question: str = "where can I buy?" + context = QnARequestContext( + previous_qna_id=5, prvious_user_query="how do I clean the stove?" + ) + options = QnAMakerOptions(top=2, qna_id=55, context=context) + turn_context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file( + "AnswerWithHighScoreProvidedContext.json" + ) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(turn_context, options) + self.assertEqual(1, len(results), "Should have received 1 answers.") + self.assertEqual(1, results[0].score, "Score should be high.") + + async def test_should_answer_with_high_score_provided_qna_id(self): + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + question: str = "where can I buy?" + + options = QnAMakerOptions(top=2, qna_id=55) + turn_context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file( + "AnswerWithHighScoreProvidedContext.json" + ) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(turn_context, options) + self.assertEqual(1, len(results), "Should have received 1 answers.") + self.assertEqual(1, results[0].score, "Score should be high.") + + async def test_should_answer_with_low_score_without_provided_context(self): + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + question: str = "where can I buy?" + options = QnAMakerOptions(top=2, context=None) + + turn_context = QnaApplicationTest._get_context(question, TestAdapter()) + response_json = QnaApplicationTest._get_json_for_file( + "AnswerWithLowScoreProvidedWithoutContext.json" + ) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + results = await qna.get_answers(turn_context, options) + self.assertEqual( + 2, len(results), "Should have received more than one answers." + ) + self.assertEqual(True, results[0].score < 1, "Score should be low.") + + @classmethod + async def _get_service_result( + cls, + utterance: str, + response_file: str, + bot_adapter: BotAdapter = TestAdapter(), + options: QnAMakerOptions = None, + ) -> [dict]: + response_json = QnaApplicationTest._get_json_for_file(response_file) + + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + context = QnaApplicationTest._get_context(utterance, bot_adapter) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + result = await qna.get_answers(context, options) + + return result + + @classmethod + async def _get_service_result_raw( + cls, + utterance: str, + response_file: str, + bot_adapter: BotAdapter = TestAdapter(), + options: QnAMakerOptions = None, + ) -> [dict]: + response_json = QnaApplicationTest._get_json_for_file(response_file) + + qna = QnAMaker(QnaApplicationTest.tests_endpoint) + context = QnaApplicationTest._get_context(utterance, bot_adapter) + + with patch( + "aiohttp.ClientSession.post", + return_value=aiounittest.futurized(response_json), + ): + result = await qna.get_answers_raw(context, options) + + return result + + @classmethod + def _get_json_for_file(cls, response_file: str) -> object: + curr_dir = path.dirname(path.abspath(__file__)) + response_path = path.join(curr_dir, "test_data", response_file) + + with open(response_path, "r", encoding="utf-8-sig") as file: + response_str = file.read() + response_json = json.loads(response_str) + + return response_json + + @staticmethod + def _get_context(question: str, bot_adapter: BotAdapter) -> TurnContext: + test_adapter = bot_adapter or TestAdapter() + activity = Activity( + type=ActivityTypes.message, + text=question, + conversation=ConversationAccount(), + recipient=ChannelAccount(), + from_property=ChannelAccount(), + ) + + return TurnContext(test_adapter, activity) + + class OverrideTelemetry(QnAMaker): + def __init__( # pylint: disable=useless-super-delegation + self, + endpoint: QnAMakerEndpoint, + options: QnAMakerOptions, + http_client: ClientSession, + telemetry_client: BotTelemetryClient, + log_personal_information: bool, + ): + super().__init__( + endpoint, + options, + http_client, + telemetry_client, + log_personal_information, + ) + + async def on_qna_result( # pylint: disable=unused-argument + self, + query_results: [QueryResult], + turn_context: TurnContext, + telemetry_properties: Dict[str, str] = None, + telemetry_metrics: Dict[str, float] = None, + ): + properties = telemetry_properties or {} + + # get_answers overrides derived class + properties["my_important_property"] = "my_important_value" + + # Log event + self.telemetry_client.track_event( + QnATelemetryConstants.qna_message_event, properties + ) + + # Create 2nd event. + second_event_properties = {"my_important_property2": "my_important_value2"} + self.telemetry_client.track_event( + "my_second_event", second_event_properties + ) + + class OverrideFillTelemetry(QnAMaker): + def __init__( # pylint: disable=useless-super-delegation + self, + endpoint: QnAMakerEndpoint, + options: QnAMakerOptions, + http_client: ClientSession, + telemetry_client: BotTelemetryClient, + log_personal_information: bool, + ): + super().__init__( + endpoint, + options, + http_client, + telemetry_client, + log_personal_information, + ) + + async def on_qna_result( + self, + query_results: [QueryResult], + turn_context: TurnContext, + telemetry_properties: Dict[str, str] = None, + telemetry_metrics: Dict[str, float] = None, + ): + event_data = await self.fill_qna_event( + query_results, turn_context, telemetry_properties, telemetry_metrics + ) + + # Add my property. + event_data.properties.update( + {"my_important_property": "my_important_value"} + ) + + # Log QnaMessage event. + self.telemetry_client.track_event( + QnATelemetryConstants.qna_message_event, + event_data.properties, + event_data.metrics, + ) + + # Create second event. + second_event_properties: Dict[str, str] = { + "my_important_property2": "my_important_value2" + } + + self.telemetry_client.track_event("MySecondEvent", second_event_properties) diff --git a/libraries/botbuilder-applicationinsights/tests/test_telemetry_waterfall.py b/libraries/botbuilder-applicationinsights/tests/test_telemetry_waterfall.py index c42adee2f..c1ab6e261 100644 --- a/libraries/botbuilder-applicationinsights/tests/test_telemetry_waterfall.py +++ b/libraries/botbuilder-applicationinsights/tests/test_telemetry_waterfall.py @@ -1,178 +1,174 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from unittest.mock import patch -from typing import Dict -import aiounittest -from botbuilder.core.adapters import TestAdapter, TestFlow -from botbuilder.schema import Activity -from botbuilder.core import ( - ConversationState, - MemoryStorage, - TurnContext, - NullTelemetryClient, -) -from botbuilder.dialogs import ( - Dialog, - DialogSet, - WaterfallDialog, - DialogTurnResult, - DialogTurnStatus, -) - -BEGIN_MESSAGE = Activity() -BEGIN_MESSAGE.text = "begin" -BEGIN_MESSAGE.type = "message" - - -class TelemetryWaterfallTests(aiounittest.AsyncTestCase): - def test_none_telemetry_client(self): - # arrange - dialog = WaterfallDialog("myId") - # act - dialog.telemetry_client = None - # assert - self.assertEqual(type(dialog.telemetry_client), NullTelemetryClient) - - @patch("botbuilder.applicationinsights.ApplicationInsightsTelemetryClient") - async def test_execute_sequence_waterfall_steps( # pylint: disable=invalid-name - self, MockTelemetry - ): - # arrange - - # Create new ConversationState with MemoryStorage and register the state as middleware. - convo_state = ConversationState(MemoryStorage()) - telemetry = MockTelemetry() - - # Create a DialogState property, DialogSet and register the WaterfallDialog. - dialog_state = convo_state.create_property("dialogState") - dialogs = DialogSet(dialog_state) - - async def step1(step) -> DialogTurnResult: - await step.context.send_activity("bot responding.") - return Dialog.end_of_turn - - async def step2(step) -> DialogTurnResult: - await step.context.send_activity("ending WaterfallDialog.") - return Dialog.end_of_turn - - # act - - my_dialog = WaterfallDialog("test", [step1, step2]) - my_dialog.telemetry_client = telemetry - dialogs.add(my_dialog) - - # Initialize TestAdapter - async def exec_test(turn_context: TurnContext) -> None: - - dialog_context = await dialogs.create_context(turn_context) - results = await dialog_context.continue_dialog() - if results.status == DialogTurnStatus.Empty: - await dialog_context.begin_dialog("test") - else: - if results.status == DialogTurnStatus.Complete: - await turn_context.send_activity(results.result) - - await convo_state.save_changes(turn_context) - - adapt = TestAdapter(exec_test) - - test_flow = TestFlow(None, adapt) - tf2 = await test_flow.send(BEGIN_MESSAGE) - tf3 = await tf2.assert_reply("bot responding.") - tf4 = await tf3.send("continue") - await tf4.assert_reply("ending WaterfallDialog.") - - # assert - - telemetry_calls = [ - ("WaterfallStart", {"DialogId": "test"}), - ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), - ("WaterfallStep", {"DialogId": "test", "StepName": "Step2of2"}), - ] - self.assert_telemetry_calls(telemetry, telemetry_calls) - - @patch("botbuilder.applicationinsights.ApplicationInsightsTelemetryClient") - async def test_ensure_end_dialog_called( - self, MockTelemetry - ): # pylint: disable=invalid-name - # arrange - - # Create new ConversationState with MemoryStorage and register the state as middleware. - convo_state = ConversationState(MemoryStorage()) - telemetry = MockTelemetry() - - # Create a DialogState property, DialogSet and register the WaterfallDialog. - dialog_state = convo_state.create_property("dialogState") - dialogs = DialogSet(dialog_state) - - async def step1(step) -> DialogTurnResult: - await step.context.send_activity("step1 response") - return Dialog.end_of_turn - - async def step2(step) -> DialogTurnResult: - await step.context.send_activity("step2 response") - return Dialog.end_of_turn - - # act - - my_dialog = WaterfallDialog("test", [step1, step2]) - my_dialog.telemetry_client = telemetry - dialogs.add(my_dialog) - - # Initialize TestAdapter - async def exec_test(turn_context: TurnContext) -> None: - - dialog_context = await dialogs.create_context(turn_context) - await dialog_context.continue_dialog() - if not turn_context.responded: - await dialog_context.begin_dialog("test", None) - await convo_state.save_changes(turn_context) - - adapt = TestAdapter(exec_test) - - test_flow = TestFlow(None, adapt) - tf2 = await test_flow.send(BEGIN_MESSAGE) - tf3 = await tf2.assert_reply("step1 response") - tf4 = await tf3.send("continue") - tf5 = await tf4.assert_reply("step2 response") - await tf5.send( - "Should hit end of steps - this will restart the dialog and trigger COMPLETE event" - ) - # assert - telemetry_calls = [ - ("WaterfallStart", {"DialogId": "test"}), - ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), - ("WaterfallStep", {"DialogId": "test", "StepName": "Step2of2"}), - ("WaterfallComplete", {"DialogId": "test"}), - ("WaterfallStart", {"DialogId": "test"}), - ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), - ] - print(str(telemetry.track_event.call_args_list)) - self.assert_telemetry_calls(telemetry, telemetry_calls) - - def assert_telemetry_call( - self, telemetry_mock, index: int, event_name: str, props: Dict[str, str] - ) -> None: - # pylint: disable=unused-variable - args, kwargs = telemetry_mock.track_event.call_args_list[index] - self.assertEqual(args[0], event_name) - - for key, val in props.items(): - self.assertTrue( - key in args[1], - msg=f"Could not find value {key} in {args[1]} for index {index}", - ) - self.assertTrue(isinstance(args[1], dict)) - self.assertTrue(val == args[1][key]) - - def assert_telemetry_calls(self, telemetry_mock, calls) -> None: - index = 0 - for event_name, props in calls: - self.assert_telemetry_call(telemetry_mock, index, event_name, props) - index += 1 - if index != len(telemetry_mock.track_event.call_args_list): - self.assertTrue( # pylint: disable=redundant-unittest-assert - False, - f"Found {len(telemetry_mock.track_event.call_args_list)} calls, testing for {index + 1}", - ) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from unittest.mock import MagicMock +from typing import Dict +import aiounittest +from botbuilder.core.adapters import TestAdapter, TestFlow +from botbuilder.schema import Activity +from botbuilder.core import ( + ConversationState, + MemoryStorage, + TurnContext, + NullTelemetryClient, +) +from botbuilder.dialogs import ( + Dialog, + DialogSet, + WaterfallDialog, + DialogTurnResult, + DialogTurnStatus, +) + +BEGIN_MESSAGE = Activity() +BEGIN_MESSAGE.text = "begin" +BEGIN_MESSAGE.type = "message" + +MOCK_TELEMETRY = "botbuilder.applicationinsights.ApplicationInsightsTelemetryClient" + + +class TelemetryWaterfallTests(aiounittest.AsyncTestCase): + def test_none_telemetry_client(self): + # arrange + dialog = WaterfallDialog("myId") + # act + dialog.telemetry_client = None + # assert + self.assertEqual(type(dialog.telemetry_client), NullTelemetryClient) + + async def test_execute_sequence_waterfall_steps(self): + # arrange + + # Create new ConversationState with MemoryStorage and register the state as middleware. + convo_state = ConversationState(MemoryStorage()) + telemetry = MagicMock(name=MOCK_TELEMETRY) + + # Create a DialogState property, DialogSet and register the WaterfallDialog. + dialog_state = convo_state.create_property("dialogState") + dialogs = DialogSet(dialog_state) + + async def step1(step) -> DialogTurnResult: + await step.context.send_activity("bot responding.") + return Dialog.end_of_turn + + async def step2(step) -> DialogTurnResult: + await step.context.send_activity("ending WaterfallDialog.") + return Dialog.end_of_turn + + # act + + my_dialog = WaterfallDialog("test", [step1, step2]) + my_dialog.telemetry_client = telemetry + dialogs.add(my_dialog) + + # Initialize TestAdapter + async def exec_test(turn_context: TurnContext) -> None: + + dialog_context = await dialogs.create_context(turn_context) + results = await dialog_context.continue_dialog() + if results.status == DialogTurnStatus.Empty: + await dialog_context.begin_dialog("test") + else: + if results.status == DialogTurnStatus.Complete: + await turn_context.send_activity(results.result) + + await convo_state.save_changes(turn_context) + + adapt = TestAdapter(exec_test) + + test_flow = TestFlow(None, adapt) + tf2 = await test_flow.send(BEGIN_MESSAGE) + tf3 = await tf2.assert_reply("bot responding.") + tf4 = await tf3.send("continue") + await tf4.assert_reply("ending WaterfallDialog.") + + # assert + + telemetry_calls = [ + ("WaterfallStart", {"DialogId": "test"}), + ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), + ("WaterfallStep", {"DialogId": "test", "StepName": "Step2of2"}), + ] + self.assert_telemetry_calls(telemetry, telemetry_calls) + + async def test_ensure_end_dialog_called(self): + # arrange + + # Create new ConversationState with MemoryStorage and register the state as middleware. + convo_state = ConversationState(MemoryStorage()) + telemetry = MagicMock(name=MOCK_TELEMETRY) + + # Create a DialogState property, DialogSet and register the WaterfallDialog. + dialog_state = convo_state.create_property("dialogState") + dialogs = DialogSet(dialog_state) + + async def step1(step) -> DialogTurnResult: + await step.context.send_activity("step1 response") + return Dialog.end_of_turn + + async def step2(step) -> DialogTurnResult: + await step.context.send_activity("step2 response") + return Dialog.end_of_turn + + # act + + my_dialog = WaterfallDialog("test", [step1, step2]) + my_dialog.telemetry_client = telemetry + dialogs.add(my_dialog) + + # Initialize TestAdapter + async def exec_test(turn_context: TurnContext) -> None: + + dialog_context = await dialogs.create_context(turn_context) + await dialog_context.continue_dialog() + if not turn_context.responded: + await dialog_context.begin_dialog("test", None) + await convo_state.save_changes(turn_context) + + adapt = TestAdapter(exec_test) + + test_flow = TestFlow(None, adapt) + tf2 = await test_flow.send(BEGIN_MESSAGE) + tf3 = await tf2.assert_reply("step1 response") + tf4 = await tf3.send("continue") + tf5 = await tf4.assert_reply("step2 response") + await tf5.send( + "Should hit end of steps - this will restart the dialog and trigger COMPLETE event" + ) + # assert + telemetry_calls = [ + ("WaterfallStart", {"DialogId": "test"}), + ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), + ("WaterfallStep", {"DialogId": "test", "StepName": "Step2of2"}), + ("WaterfallComplete", {"DialogId": "test"}), + ("WaterfallStart", {"DialogId": "test"}), + ("WaterfallStep", {"DialogId": "test", "StepName": "Step1of2"}), + ] + print(str(telemetry.track_event.call_args_list)) + self.assert_telemetry_calls(telemetry, telemetry_calls) + + def assert_telemetry_call( + self, telemetry_mock, index: int, event_name: str, props: Dict[str, str] + ) -> None: + # pylint: disable=unused-variable + args, kwargs = telemetry_mock.track_event.call_args_list[index] + self.assertEqual(args[0], event_name) + + for key, val in props.items(): + self.assertTrue( + key in args[1], + msg=f"Could not find value {key} in {args[1]} for index {index}", + ) + self.assertTrue(isinstance(args[1], dict)) + self.assertTrue(val == args[1][key]) + + def assert_telemetry_calls(self, telemetry_mock, calls) -> None: + index = 0 + for event_name, props in calls: + self.assert_telemetry_call(telemetry_mock, index, event_name, props) + index += 1 + if index != len(telemetry_mock.track_event.call_args_list): + self.assertTrue( # pylint: disable=redundant-unittest-assert + False, + f"Found {len(telemetry_mock.track_event.call_args_list)} calls, testing for {index + 1}", + ) diff --git a/libraries/botbuilder-core/botbuilder/core/adapters/test_adapter.py b/libraries/botbuilder-core/botbuilder/core/adapters/test_adapter.py index 0ff9f16b6..d8acd678c 100644 --- a/libraries/botbuilder-core/botbuilder/core/adapters/test_adapter.py +++ b/libraries/botbuilder-core/botbuilder/core/adapters/test_adapter.py @@ -1,463 +1,467 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -# TODO: enable this in the future -# With python 3.7 the line below will allow to do Postponed Evaluation of Annotations. See PEP 563 -# from __future__ import annotations - -import asyncio -import inspect -from datetime import datetime -from typing import Awaitable, Coroutine, Dict, List, Callable, Union -from copy import copy -from threading import Lock -from botbuilder.schema import ( - ActivityTypes, - Activity, - ConversationAccount, - ConversationReference, - ChannelAccount, - ResourceResponse, - TokenResponse, -) -from botframework.connector.auth import ClaimsIdentity -from ..bot_adapter import BotAdapter -from ..turn_context import TurnContext -from ..user_token_provider import UserTokenProvider - - -class UserToken: - def __init__( - self, - connection_name: str = None, - user_id: str = None, - channel_id: str = None, - token: str = None, - ): - self.connection_name = connection_name - self.user_id = user_id - self.channel_id = channel_id - self.token = token - - def equals_key(self, rhs: "UserToken"): - return ( - rhs is not None - and self.connection_name == rhs.connection_name - and self.user_id == rhs.user_id - and self.channel_id == rhs.channel_id - ) - - -class TokenMagicCode: - def __init__(self, key: UserToken = None, magic_code: str = None): - self.key = key - self.magic_code = magic_code - - -class TestAdapter(BotAdapter, UserTokenProvider): - def __init__( - self, - logic: Coroutine = None, - template_or_conversation: Union[Activity, ConversationReference] = None, - send_trace_activities: bool = False, - ): # pylint: disable=unused-argument - """ - Creates a new TestAdapter instance. - :param logic: - :param conversation: A reference to the conversation to begin the adapter state with. - """ - super(TestAdapter, self).__init__() - self.logic = logic - self._next_id: int = 0 - self._user_tokens: List[UserToken] = [] - self._magic_codes: List[TokenMagicCode] = [] - self._conversation_lock = Lock() - self.activity_buffer: List[Activity] = [] - self.updated_activities: List[Activity] = [] - self.deleted_activities: List[ConversationReference] = [] - self.send_trace_activities = send_trace_activities - - self.template = ( - template_or_conversation - if isinstance(template_or_conversation, Activity) - else Activity( - channel_id="test", - service_url="https://test.com", - from_property=ChannelAccount(id="User1", name="user"), - recipient=ChannelAccount(id="bot", name="Bot"), - conversation=ConversationAccount(id="Convo1"), - ) - ) - - if isinstance(template_or_conversation, ConversationReference): - self.template.channel_id = template_or_conversation.channel_id - - async def process_activity( - self, activity: Activity, logic: Callable[[TurnContext], Awaitable] - ): - self._conversation_lock.acquire() - try: - # ready for next reply - if activity.type is None: - activity.type = ActivityTypes.message - - activity.channel_id = self.template.channel_id - activity.from_property = self.template.from_property - activity.recipient = self.template.recipient - activity.conversation = self.template.conversation - activity.service_url = self.template.service_url - - activity.id = str((self._next_id)) - self._next_id += 1 - finally: - self._conversation_lock.release() - - activity.timestamp = activity.timestamp or datetime.utcnow() - await self.run_pipeline(TurnContext(self, activity), logic) - - async def send_activities( - self, context, activities: List[Activity] - ) -> List[ResourceResponse]: - """ - INTERNAL: called by the logic under test to send a set of activities. These will be buffered - to the current `TestFlow` instance for comparison against the expected results. - :param context: - :param activities: - :return: - """ - - def id_mapper(activity): - self.activity_buffer.append(activity) - self._next_id += 1 - return ResourceResponse(id=str(self._next_id)) - - return [ - id_mapper(activity) - for activity in activities - if self.send_trace_activities or activity.type != "trace" - ] - - async def delete_activity(self, context, reference: ConversationReference): - """ - INTERNAL: called by the logic under test to delete an existing activity. These are simply - pushed onto a [deletedActivities](#deletedactivities) array for inspection after the turn - completes. - :param reference: - :return: - """ - self.deleted_activities.append(reference) - - async def update_activity(self, context, activity: Activity): - """ - INTERNAL: called by the logic under test to replace an existing activity. These are simply - pushed onto an [updatedActivities](#updatedactivities) array for inspection after the turn - completes. - :param activity: - :return: - """ - self.updated_activities.append(activity) - - async def continue_conversation( - self, - reference: ConversationReference, - callback: Callable, - bot_id: str = None, - claims_identity: ClaimsIdentity = None, # pylint: disable=unused-argument - ): - """ - The `TestAdapter` just calls parent implementation. - :param reference: - :param callback: - :param bot_id: - :param claims_identity: - :return: - """ - await super().continue_conversation( - reference, callback, bot_id, claims_identity - ) - - async def receive_activity(self, activity): - """ - INTERNAL: called by a `TestFlow` instance to simulate a user sending a message to the bot. - This will cause the adapters middleware pipe to be run and it's logic to be called. - :param activity: - :return: - """ - if isinstance(activity, str): - activity = Activity(type="message", text=activity) - # Initialize request. - request = copy(self.template) - - for key, value in vars(activity).items(): - if value is not None and key != "additional_properties": - setattr(request, key, value) - - request.type = request.type or ActivityTypes.message - if not request.id: - self._next_id += 1 - request.id = str(self._next_id) - - # Create context object and run middleware. - context = TurnContext(self, request) - return await self.run_pipeline(context, self.logic) - - def get_next_activity(self) -> Activity: - return self.activity_buffer.pop(0) - - async def send(self, user_says) -> object: - """ - Sends something to the bot. This returns a new `TestFlow` instance which can be used to add - additional steps for inspecting the bots reply and then sending additional activities. - :param user_says: - :return: A new instance of the TestFlow object - """ - return TestFlow(await self.receive_activity(user_says), self) - - async def test( - self, user_says, expected, description=None, timeout=None - ) -> "TestFlow": - """ - Send something to the bot and expects the bot to return with a given reply. This is simply a - wrapper around calls to `send()` and `assertReply()`. This is such a common pattern that a - helper is provided. - :param user_says: - :param expected: - :param description: - :param timeout: - :return: - """ - test_flow = await self.send(user_says) - test_flow = await test_flow.assert_reply(expected, description, timeout) - return test_flow - - async def tests(self, *args): - """ - Support multiple test cases without having to manually call `test()` repeatedly. This is a - convenience layer around the `test()`. Valid args are either lists or tuples of parameters - :param args: - :return: - """ - for arg in args: - description = None - timeout = None - if len(arg) >= 3: - description = arg[2] - if len(arg) == 4: - timeout = arg[3] - await self.test(arg[0], arg[1], description, timeout) - - def add_user_token( - self, - connection_name: str, - channel_id: str, - user_id: str, - token: str, - magic_code: str = None, - ): - key = UserToken() - key.channel_id = channel_id - key.connection_name = connection_name - key.user_id = user_id - key.token = token - - if not magic_code: - self._user_tokens.append(key) - else: - code = TokenMagicCode() - code.key = key - code.magic_code = magic_code - self._magic_codes.append(code) - - async def get_user_token( - self, context: TurnContext, connection_name: str, magic_code: str = None - ) -> TokenResponse: - key = UserToken() - key.channel_id = context.activity.channel_id - key.connection_name = connection_name - key.user_id = context.activity.from_property.id - - if magic_code: - magic_code_record = list( - filter(lambda x: key.equals_key(x.key), self._magic_codes) - ) - if magic_code_record and magic_code_record[0].magic_code == magic_code: - # Move the token to long term dictionary. - self.add_user_token( - connection_name, - key.channel_id, - key.user_id, - magic_code_record[0].key.token, - ) - - # Remove from the magic code list. - idx = self._magic_codes.index(magic_code_record[0]) - self._magic_codes = [self._magic_codes.pop(idx)] - - match = [token for token in self._user_tokens if key.equals_key(token)] - - if match: - return TokenResponse( - connection_name=match[0].connection_name, - token=match[0].token, - expiration=None, - ) - # Not found. - return None - - async def sign_out_user( - self, context: TurnContext, connection_name: str, user_id: str = None - ): - channel_id = context.activity.channel_id - user_id = context.activity.from_property.id - - new_records = [] - for token in self._user_tokens: - if ( - token.channel_id != channel_id - or token.user_id != user_id - or (connection_name and connection_name != token.connection_name) - ): - new_records.append(token) - self._user_tokens = new_records - - async def get_oauth_sign_in_link( - self, context: TurnContext, connection_name: str - ) -> str: - return ( - f"https://fake.com/oauthsignin" - f"/{connection_name}/{context.activity.channel_id}/{context.activity.from_property.id}" - ) - - async def get_aad_tokens( - self, context: TurnContext, connection_name: str, resource_urls: List[str] - ) -> Dict[str, TokenResponse]: - return None - - -class TestFlow: - def __init__(self, previous: Callable, adapter: TestAdapter): - """ - INTERNAL: creates a new TestFlow instance. - :param previous: - :param adapter: - """ - self.previous = previous - self.adapter = adapter - - async def test( - self, user_says, expected, description=None, timeout=None - ) -> "TestFlow": - """ - Send something to the bot and expects the bot to return with a given reply. This is simply a - wrapper around calls to `send()` and `assertReply()`. This is such a common pattern that a - helper is provided. - :param user_says: - :param expected: - :param description: - :param timeout: - :return: - """ - test_flow = await self.send(user_says) - return await test_flow.assert_reply( - expected, description or f'test("{user_says}", "{expected}")', timeout - ) - - async def send(self, user_says) -> "TestFlow": - """ - Sends something to the bot. - :param user_says: - :return: - """ - - async def new_previous(): - nonlocal self, user_says - if callable(self.previous): - await self.previous() - await self.adapter.receive_activity(user_says) - - return TestFlow(await new_previous(), self.adapter) - - async def assert_reply( - self, - expected: Union[str, Activity, Callable[[Activity, str], None]], - description=None, - timeout=None, # pylint: disable=unused-argument - is_substring=False, - ) -> "TestFlow": - """ - Generates an assertion if the bots response doesn't match the expected text/activity. - :param expected: - :param description: - :param timeout: - :param is_substring: - :return: - """ - # TODO: refactor method so expected can take a Callable[[Activity], None] - def default_inspector(reply, description=None): - if isinstance(expected, Activity): - validate_activity(reply, expected) - else: - assert reply.type == "message", description + f" type == {reply.type}" - if is_substring: - assert expected in reply.text.strip(), ( - description + f" text == {reply.text}" - ) - else: - assert reply.text.strip() == expected.strip(), ( - description + f" text == {reply.text}" - ) - - if description is None: - description = "" - - inspector = expected if callable(expected) else default_inspector - - async def test_flow_previous(): - nonlocal timeout - if not timeout: - timeout = 3000 - start = datetime.now() - adapter = self.adapter - - async def wait_for_activity(): - nonlocal expected, timeout - current = datetime.now() - if (current - start).total_seconds() * 1000 > timeout: - if isinstance(expected, Activity): - expecting = expected.text - elif callable(expected): - expecting = inspect.getsourcefile(expected) - else: - expecting = str(expected) - raise RuntimeError( - f"TestAdapter.assert_reply({expecting}): {description} Timed out after " - f"{current - start}ms." - ) - if adapter.activity_buffer: - reply = adapter.activity_buffer.pop(0) - try: - await inspector(reply, description) - except Exception: - inspector(reply, description) - - else: - await asyncio.sleep(0.05) - await wait_for_activity() - - await wait_for_activity() - - return TestFlow(await test_flow_previous(), self.adapter) - - -def validate_activity(activity, expected) -> None: - """ - Helper method that compares activities - :param activity: - :param expected: - :return: - """ - iterable_expected = vars(expected).items() - - for attr, value in iterable_expected: - if value is not None and attr != "additional_properties": - assert value == getattr(activity, attr) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# TODO: enable this in the future +# With python 3.7 the line below will allow to do Postponed Evaluation of Annotations. See PEP 563 +# from __future__ import annotations + +import asyncio +import inspect +from datetime import datetime +from typing import Awaitable, Coroutine, Dict, List, Callable, Union +from copy import copy +from threading import Lock +from botbuilder.schema import ( + ActivityTypes, + Activity, + ConversationAccount, + ConversationReference, + ChannelAccount, + ResourceResponse, + TokenResponse, +) +from botframework.connector.auth import ClaimsIdentity +from ..bot_adapter import BotAdapter +from ..turn_context import TurnContext +from ..user_token_provider import UserTokenProvider + + +class UserToken: + def __init__( + self, + connection_name: str = None, + user_id: str = None, + channel_id: str = None, + token: str = None, + ): + self.connection_name = connection_name + self.user_id = user_id + self.channel_id = channel_id + self.token = token + + def equals_key(self, rhs: "UserToken"): + return ( + rhs is not None + and self.connection_name == rhs.connection_name + and self.user_id == rhs.user_id + and self.channel_id == rhs.channel_id + ) + + +class TokenMagicCode: + def __init__(self, key: UserToken = None, magic_code: str = None): + self.key = key + self.magic_code = magic_code + + +class TestAdapter(BotAdapter, UserTokenProvider): + __test__ = False + + def __init__( + self, + logic: Coroutine = None, + template_or_conversation: Union[Activity, ConversationReference] = None, + send_trace_activities: bool = False, + ): # pylint: disable=unused-argument + """ + Creates a new TestAdapter instance. + :param logic: + :param conversation: A reference to the conversation to begin the adapter state with. + """ + super(TestAdapter, self).__init__() + self.logic = logic + self._next_id: int = 0 + self._user_tokens: List[UserToken] = [] + self._magic_codes: List[TokenMagicCode] = [] + self._conversation_lock = Lock() + self.activity_buffer: List[Activity] = [] + self.updated_activities: List[Activity] = [] + self.deleted_activities: List[ConversationReference] = [] + self.send_trace_activities = send_trace_activities + + self.template = ( + template_or_conversation + if isinstance(template_or_conversation, Activity) + else Activity( + channel_id="test", + service_url="https://test.com", + from_property=ChannelAccount(id="User1", name="user"), + recipient=ChannelAccount(id="bot", name="Bot"), + conversation=ConversationAccount(id="Convo1"), + ) + ) + + if isinstance(template_or_conversation, ConversationReference): + self.template.channel_id = template_or_conversation.channel_id + + async def process_activity( + self, activity: Activity, logic: Callable[[TurnContext], Awaitable] + ): + self._conversation_lock.acquire() + try: + # ready for next reply + if activity.type is None: + activity.type = ActivityTypes.message + + activity.channel_id = self.template.channel_id + activity.from_property = self.template.from_property + activity.recipient = self.template.recipient + activity.conversation = self.template.conversation + activity.service_url = self.template.service_url + + activity.id = str((self._next_id)) + self._next_id += 1 + finally: + self._conversation_lock.release() + + activity.timestamp = activity.timestamp or datetime.utcnow() + await self.run_pipeline(TurnContext(self, activity), logic) + + async def send_activities( + self, context, activities: List[Activity] + ) -> List[ResourceResponse]: + """ + INTERNAL: called by the logic under test to send a set of activities. These will be buffered + to the current `TestFlow` instance for comparison against the expected results. + :param context: + :param activities: + :return: + """ + + def id_mapper(activity): + self.activity_buffer.append(activity) + self._next_id += 1 + return ResourceResponse(id=str(self._next_id)) + + return [ + id_mapper(activity) + for activity in activities + if self.send_trace_activities or activity.type != "trace" + ] + + async def delete_activity(self, context, reference: ConversationReference): + """ + INTERNAL: called by the logic under test to delete an existing activity. These are simply + pushed onto a [deletedActivities](#deletedactivities) array for inspection after the turn + completes. + :param reference: + :return: + """ + self.deleted_activities.append(reference) + + async def update_activity(self, context, activity: Activity): + """ + INTERNAL: called by the logic under test to replace an existing activity. These are simply + pushed onto an [updatedActivities](#updatedactivities) array for inspection after the turn + completes. + :param activity: + :return: + """ + self.updated_activities.append(activity) + + async def continue_conversation( + self, + reference: ConversationReference, + callback: Callable, + bot_id: str = None, + claims_identity: ClaimsIdentity = None, # pylint: disable=unused-argument + ): + """ + The `TestAdapter` just calls parent implementation. + :param reference: + :param callback: + :param bot_id: + :param claims_identity: + :return: + """ + await super().continue_conversation( + reference, callback, bot_id, claims_identity + ) + + async def receive_activity(self, activity): + """ + INTERNAL: called by a `TestFlow` instance to simulate a user sending a message to the bot. + This will cause the adapters middleware pipe to be run and it's logic to be called. + :param activity: + :return: + """ + if isinstance(activity, str): + activity = Activity(type="message", text=activity) + # Initialize request. + request = copy(self.template) + + for key, value in vars(activity).items(): + if value is not None and key != "additional_properties": + setattr(request, key, value) + + request.type = request.type or ActivityTypes.message + if not request.id: + self._next_id += 1 + request.id = str(self._next_id) + + # Create context object and run middleware. + context = TurnContext(self, request) + return await self.run_pipeline(context, self.logic) + + def get_next_activity(self) -> Activity: + return self.activity_buffer.pop(0) + + async def send(self, user_says) -> object: + """ + Sends something to the bot. This returns a new `TestFlow` instance which can be used to add + additional steps for inspecting the bots reply and then sending additional activities. + :param user_says: + :return: A new instance of the TestFlow object + """ + return TestFlow(await self.receive_activity(user_says), self) + + async def test( + self, user_says, expected, description=None, timeout=None + ) -> "TestFlow": + """ + Send something to the bot and expects the bot to return with a given reply. This is simply a + wrapper around calls to `send()` and `assertReply()`. This is such a common pattern that a + helper is provided. + :param user_says: + :param expected: + :param description: + :param timeout: + :return: + """ + test_flow = await self.send(user_says) + test_flow = await test_flow.assert_reply(expected, description, timeout) + return test_flow + + async def tests(self, *args): + """ + Support multiple test cases without having to manually call `test()` repeatedly. This is a + convenience layer around the `test()`. Valid args are either lists or tuples of parameters + :param args: + :return: + """ + for arg in args: + description = None + timeout = None + if len(arg) >= 3: + description = arg[2] + if len(arg) == 4: + timeout = arg[3] + await self.test(arg[0], arg[1], description, timeout) + + def add_user_token( + self, + connection_name: str, + channel_id: str, + user_id: str, + token: str, + magic_code: str = None, + ): + key = UserToken() + key.channel_id = channel_id + key.connection_name = connection_name + key.user_id = user_id + key.token = token + + if not magic_code: + self._user_tokens.append(key) + else: + code = TokenMagicCode() + code.key = key + code.magic_code = magic_code + self._magic_codes.append(code) + + async def get_user_token( + self, context: TurnContext, connection_name: str, magic_code: str = None + ) -> TokenResponse: + key = UserToken() + key.channel_id = context.activity.channel_id + key.connection_name = connection_name + key.user_id = context.activity.from_property.id + + if magic_code: + magic_code_record = list( + filter(lambda x: key.equals_key(x.key), self._magic_codes) + ) + if magic_code_record and magic_code_record[0].magic_code == magic_code: + # Move the token to long term dictionary. + self.add_user_token( + connection_name, + key.channel_id, + key.user_id, + magic_code_record[0].key.token, + ) + + # Remove from the magic code list. + idx = self._magic_codes.index(magic_code_record[0]) + self._magic_codes = [self._magic_codes.pop(idx)] + + match = [token for token in self._user_tokens if key.equals_key(token)] + + if match: + return TokenResponse( + connection_name=match[0].connection_name, + token=match[0].token, + expiration=None, + ) + # Not found. + return None + + async def sign_out_user( + self, context: TurnContext, connection_name: str, user_id: str = None + ): + channel_id = context.activity.channel_id + user_id = context.activity.from_property.id + + new_records = [] + for token in self._user_tokens: + if ( + token.channel_id != channel_id + or token.user_id != user_id + or (connection_name and connection_name != token.connection_name) + ): + new_records.append(token) + self._user_tokens = new_records + + async def get_oauth_sign_in_link( + self, context: TurnContext, connection_name: str + ) -> str: + return ( + f"https://fake.com/oauthsignin" + f"/{connection_name}/{context.activity.channel_id}/{context.activity.from_property.id}" + ) + + async def get_aad_tokens( + self, context: TurnContext, connection_name: str, resource_urls: List[str] + ) -> Dict[str, TokenResponse]: + return None + + +class TestFlow: + __test__ = False + + def __init__(self, previous: Callable, adapter: TestAdapter): + """ + INTERNAL: creates a new TestFlow instance. + :param previous: + :param adapter: + """ + self.previous = previous + self.adapter = adapter + + async def test( + self, user_says, expected, description=None, timeout=None + ) -> "TestFlow": + """ + Send something to the bot and expects the bot to return with a given reply. This is simply a + wrapper around calls to `send()` and `assertReply()`. This is such a common pattern that a + helper is provided. + :param user_says: + :param expected: + :param description: + :param timeout: + :return: + """ + test_flow = await self.send(user_says) + return await test_flow.assert_reply( + expected, description or f'test("{user_says}", "{expected}")', timeout + ) + + async def send(self, user_says) -> "TestFlow": + """ + Sends something to the bot. + :param user_says: + :return: + """ + + async def new_previous(): + nonlocal self, user_says + if callable(self.previous): + await self.previous() + await self.adapter.receive_activity(user_says) + + return TestFlow(await new_previous(), self.adapter) + + async def assert_reply( + self, + expected: Union[str, Activity, Callable[[Activity, str], None]], + description=None, + timeout=None, # pylint: disable=unused-argument + is_substring=False, + ) -> "TestFlow": + """ + Generates an assertion if the bots response doesn't match the expected text/activity. + :param expected: + :param description: + :param timeout: + :param is_substring: + :return: + """ + # TODO: refactor method so expected can take a Callable[[Activity], None] + def default_inspector(reply, description=None): + if isinstance(expected, Activity): + validate_activity(reply, expected) + else: + assert reply.type == "message", description + f" type == {reply.type}" + if is_substring: + assert expected in reply.text.strip(), ( + description + f" text == {reply.text}" + ) + else: + assert reply.text.strip() == expected.strip(), ( + description + f" text == {reply.text}" + ) + + if description is None: + description = "" + + inspector = expected if callable(expected) else default_inspector + + async def test_flow_previous(): + nonlocal timeout + if not timeout: + timeout = 3000 + start = datetime.now() + adapter = self.adapter + + async def wait_for_activity(): + nonlocal expected, timeout + current = datetime.now() + if (current - start).total_seconds() * 1000 > timeout: + if isinstance(expected, Activity): + expecting = expected.text + elif callable(expected): + expecting = inspect.getsourcefile(expected) + else: + expecting = str(expected) + raise RuntimeError( + f"TestAdapter.assert_reply({expecting}): {description} Timed out after " + f"{current - start}ms." + ) + if adapter.activity_buffer: + reply = adapter.activity_buffer.pop(0) + try: + await inspector(reply, description) + except Exception: + inspector(reply, description) + + else: + await asyncio.sleep(0.05) + await wait_for_activity() + + await wait_for_activity() + + return TestFlow(await test_flow_previous(), self.adapter) + + +def validate_activity(activity, expected) -> None: + """ + Helper method that compares activities + :param activity: + :param expected: + :return: + """ + iterable_expected = vars(expected).items() + + for attr, value in iterable_expected: + if value is not None and attr != "additional_properties": + assert value == getattr(activity, attr) diff --git a/libraries/botbuilder-core/botbuilder/core/show_typing_middleware.py b/libraries/botbuilder-core/botbuilder/core/show_typing_middleware.py index ea10e3fac..6002fbcc7 100644 --- a/libraries/botbuilder-core/botbuilder/core/show_typing_middleware.py +++ b/libraries/botbuilder-core/botbuilder/core/show_typing_middleware.py @@ -1,95 +1,95 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import time -from functools import wraps -from typing import Awaitable, Callable - -from botbuilder.schema import Activity, ActivityTypes - -from .middleware_set import Middleware -from .turn_context import TurnContext - - -def delay(span=0.0): - def wrap(func): - @wraps(func) - async def delayed(): - time.sleep(span) - await func() - - return delayed - - return wrap - - -class Timer: - clear_timer = False - - async def set_timeout(self, func, time): - is_invocation_cancelled = False - - @delay(time) - async def some_fn(): # pylint: disable=function-redefined - if not self.clear_timer: - await func() - - await some_fn() - return is_invocation_cancelled - - def set_clear_timer(self): - self.clear_timer = True - - -class ShowTypingMiddleware(Middleware): - def __init__(self, delay: float = 0.5, period: float = 2.0): - if delay < 0: - raise ValueError("Delay must be greater than or equal to zero") - - if period <= 0: - raise ValueError("Repeat period must be greater than zero") - - self._delay = delay - self._period = period - - async def on_turn( - self, context: TurnContext, logic: Callable[[TurnContext], Awaitable] - ): - finished = False - timer = Timer() - - async def start_interval(context: TurnContext, delay: int, period: int): - async def aux(): - if not finished: - typing_activity = Activity( - type=ActivityTypes.typing, - relates_to=context.activity.relates_to, - ) - - conversation_reference = TurnContext.get_conversation_reference( - context.activity - ) - - typing_activity = TurnContext.apply_conversation_reference( - typing_activity, conversation_reference - ) - - await context.adapter.send_activities(context, [typing_activity]) - - start_interval(context, period, period) - - await timer.set_timeout(aux, delay) - - def stop_interval(): - nonlocal finished - finished = True - timer.set_clear_timer() - - if context.activity.type == ActivityTypes.message: - finished = False - await start_interval(context, self._delay, self._period) - - result = await logic() - stop_interval() - - return result +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import time +from functools import wraps +from typing import Awaitable, Callable + +from botbuilder.schema import Activity, ActivityTypes + +from .middleware_set import Middleware +from .turn_context import TurnContext + + +def delay(span=0.0): + def wrap(func): + @wraps(func) + async def delayed(): + time.sleep(span) + await func() + + return delayed + + return wrap + + +class Timer: + clear_timer = False + + async def set_timeout(self, func, time): + is_invocation_cancelled = False + + @delay(time) + async def some_fn(): # pylint: disable=function-redefined + if not self.clear_timer: + await func() + + await some_fn() + return is_invocation_cancelled + + def set_clear_timer(self): + self.clear_timer = True + + +class ShowTypingMiddleware(Middleware): + def __init__(self, delay: float = 0.5, period: float = 2.0): + if delay < 0: + raise ValueError("Delay must be greater than or equal to zero") + + if period <= 0: + raise ValueError("Repeat period must be greater than zero") + + self._delay = delay + self._period = period + + async def on_turn( + self, context: TurnContext, logic: Callable[[TurnContext], Awaitable] + ): + finished = False + timer = Timer() + + async def start_interval(context: TurnContext, delay: int, period: int): + async def aux(): + if not finished: + typing_activity = Activity( + type=ActivityTypes.typing, + relates_to=context.activity.relates_to, + ) + + conversation_reference = TurnContext.get_conversation_reference( + context.activity + ) + + typing_activity = TurnContext.apply_conversation_reference( + typing_activity, conversation_reference + ) + + await context.adapter.send_activities(context, [typing_activity]) + + start_interval(context, period, period) + + await timer.set_timeout(aux, delay) + + def stop_interval(): + nonlocal finished + finished = True + timer.set_clear_timer() + + if context.activity.type == ActivityTypes.message: + finished = False + await start_interval(context, self._delay, self._period) + + result = await logic() + stop_interval() + + return result diff --git a/libraries/botbuilder-core/tests/simple_adapter.py b/libraries/botbuilder-core/tests/simple_adapter.py index a80fa29b3..1202ad7f1 100644 --- a/libraries/botbuilder-core/tests/simple_adapter.py +++ b/libraries/botbuilder-core/tests/simple_adapter.py @@ -1,60 +1,60 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest -from typing import List -from botbuilder.core import BotAdapter, TurnContext -from botbuilder.schema import Activity, ConversationReference, ResourceResponse - - -class SimpleAdapter(BotAdapter): - # pylint: disable=unused-argument - - def __init__(self, call_on_send=None, call_on_update=None, call_on_delete=None): - super(SimpleAdapter, self).__init__() - self.test_aux = unittest.TestCase("__init__") - self._call_on_send = call_on_send - self._call_on_update = call_on_update - self._call_on_delete = call_on_delete - - async def delete_activity( - self, context: TurnContext, reference: ConversationReference - ): - self.test_aux.assertIsNotNone( - reference, "SimpleAdapter.delete_activity: missing reference" - ) - if self._call_on_delete is not None: - self._call_on_delete(reference) - - async def send_activities( - self, context: TurnContext, activities: List[Activity] - ) -> List[ResourceResponse]: - self.test_aux.assertIsNotNone( - activities, "SimpleAdapter.delete_activity: missing reference" - ) - self.test_aux.assertTrue( - len(activities) > 0, - "SimpleAdapter.send_activities: empty activities array.", - ) - - if self._call_on_send is not None: - self._call_on_send(activities) - responses = [] - - for activity in activities: - responses.append(ResourceResponse(id=activity.id)) - - return responses - - async def update_activity(self, context: TurnContext, activity: Activity): - self.test_aux.assertIsNotNone( - activity, "SimpleAdapter.update_activity: missing activity" - ) - if self._call_on_update is not None: - self._call_on_update(activity) - - return ResourceResponse(activity.id) - - async def process_request(self, activity, handler): - context = TurnContext(self, activity) - return self.run_pipeline(context, handler) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import unittest +from typing import List +from botbuilder.core import BotAdapter, TurnContext +from botbuilder.schema import Activity, ConversationReference, ResourceResponse + + +class SimpleAdapter(BotAdapter): + # pylint: disable=unused-argument + + def __init__(self, call_on_send=None, call_on_update=None, call_on_delete=None): + super(SimpleAdapter, self).__init__() + self.test_aux = unittest.TestCase("__init__") + self._call_on_send = call_on_send + self._call_on_update = call_on_update + self._call_on_delete = call_on_delete + + async def delete_activity( + self, context: TurnContext, reference: ConversationReference + ): + self.test_aux.assertIsNotNone( + reference, "SimpleAdapter.delete_activity: missing reference" + ) + if self._call_on_delete is not None: + self._call_on_delete(reference) + + async def send_activities( + self, context: TurnContext, activities: List[Activity] + ) -> List[ResourceResponse]: + self.test_aux.assertIsNotNone( + activities, "SimpleAdapter.delete_activity: missing reference" + ) + self.test_aux.assertTrue( + len(activities) > 0, + "SimpleAdapter.send_activities: empty activities array.", + ) + + if self._call_on_send is not None: + self._call_on_send(activities) + responses = [] + + for activity in activities: + responses.append(ResourceResponse(id=activity.id)) + + return responses + + async def update_activity(self, context: TurnContext, activity: Activity): + self.test_aux.assertIsNotNone( + activity, "SimpleAdapter.update_activity: missing activity" + ) + if self._call_on_update is not None: + self._call_on_update(activity) + + return ResourceResponse(activity.id) + + async def process_request(self, activity, handler): + context = TurnContext(self, activity) + return await self.run_pipeline(context, handler) diff --git a/libraries/botbuilder-core/tests/teams/test_teams_activity_handler.py b/libraries/botbuilder-core/tests/teams/test_teams_activity_handler.py index 8de03909c..38c4e2c14 100644 --- a/libraries/botbuilder-core/tests/teams/test_teams_activity_handler.py +++ b/libraries/botbuilder-core/tests/teams/test_teams_activity_handler.py @@ -1,724 +1,726 @@ -from typing import List - -import aiounittest -from botbuilder.core import BotAdapter, TurnContext -from botbuilder.core.teams import TeamsActivityHandler -from botbuilder.schema import ( - Activity, - ActivityTypes, - ChannelAccount, - ConversationReference, - ResourceResponse, -) -from botbuilder.schema.teams import ( - AppBasedLinkQuery, - ChannelInfo, - FileConsentCardResponse, - MessageActionsPayload, - MessagingExtensionAction, - MessagingExtensionQuery, - O365ConnectorCardActionQuery, - TaskModuleRequest, - TaskModuleRequestContext, - TeamInfo, - TeamsChannelAccount, -) -from botframework.connector import Channels -from simple_adapter import SimpleAdapter - - -class TestingTeamsActivityHandler(TeamsActivityHandler): - def __init__(self): - self.record: List[str] = [] - - async def on_conversation_update_activity(self, turn_context: TurnContext): - self.record.append("on_conversation_update_activity") - return await super().on_conversation_update_activity(turn_context) - - async def on_teams_members_removed( - self, teams_members_removed: [TeamsChannelAccount], turn_context: TurnContext - ): - self.record.append("on_teams_members_removed") - return await super().on_teams_members_removed( - teams_members_removed, turn_context - ) - - async def on_message_activity(self, turn_context: TurnContext): - self.record.append("on_message_activity") - return await super().on_message_activity(turn_context) - - async def on_token_response_event(self, turn_context: TurnContext): - self.record.append("on_token_response_event") - return await super().on_token_response_event(turn_context) - - async def on_event(self, turn_context: TurnContext): - self.record.append("on_event") - return await super().on_event(turn_context) - - async def on_unrecognized_activity_type(self, turn_context: TurnContext): - self.record.append("on_unrecognized_activity_type") - return await super().on_unrecognized_activity_type(turn_context) - - async def on_teams_channel_created( - self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext - ): - self.record.append("on_teams_channel_created") - return await super().on_teams_channel_created( - channel_info, team_info, turn_context - ) - - async def on_teams_channel_renamed( - self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext - ): - self.record.append("on_teams_channel_renamed") - return await super().on_teams_channel_renamed( - channel_info, team_info, turn_context - ) - - async def on_teams_channel_deleted( - self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext - ): - self.record.append("on_teams_channel_deleted") - return await super().on_teams_channel_renamed( - channel_info, team_info, turn_context - ) - - async def on_teams_team_renamed_activity( - self, team_info: TeamInfo, turn_context: TurnContext - ): - self.record.append("on_teams_team_renamed_activity") - return await super().on_teams_team_renamed_activity(team_info, turn_context) - - async def on_invoke_activity(self, turn_context: TurnContext): - self.record.append("on_invoke_activity") - return await super().on_invoke_activity(turn_context) - - async def on_teams_signin_verify_state(self, turn_context: TurnContext): - self.record.append("on_teams_signin_verify_state") - return await super().on_teams_signin_verify_state(turn_context) - - async def on_teams_file_consent( - self, - turn_context: TurnContext, - file_consent_card_response: FileConsentCardResponse, - ): - self.record.append("on_teams_file_consent") - return await super().on_teams_file_consent( - turn_context, file_consent_card_response - ) - - async def on_teams_file_consent_accept( - self, - turn_context: TurnContext, - file_consent_card_response: FileConsentCardResponse, - ): - self.record.append("on_teams_file_consent_accept") - return await super().on_teams_file_consent_accept( - turn_context, file_consent_card_response - ) - - async def on_teams_file_consent_decline( - self, - turn_context: TurnContext, - file_consent_card_response: FileConsentCardResponse, - ): - self.record.append("on_teams_file_consent_decline") - return await super().on_teams_file_consent_decline( - turn_context, file_consent_card_response - ) - - async def on_teams_o365_connector_card_action( - self, turn_context: TurnContext, query: O365ConnectorCardActionQuery - ): - self.record.append("on_teams_o365_connector_card_action") - return await super().on_teams_o365_connector_card_action(turn_context, query) - - async def on_teams_app_based_link_query( - self, turn_context: TurnContext, query: AppBasedLinkQuery - ): - self.record.append("on_teams_app_based_link_query") - return await super().on_teams_app_based_link_query(turn_context, query) - - async def on_teams_messaging_extension_query( - self, turn_context: TurnContext, query: MessagingExtensionQuery - ): - self.record.append("on_teams_messaging_extension_query") - return await super().on_teams_messaging_extension_query(turn_context, query) - - async def on_teams_messaging_extension_submit_action_dispatch( - self, turn_context: TurnContext, action: MessagingExtensionAction - ): - self.record.append("on_teams_messaging_extension_submit_action_dispatch") - return await super().on_teams_messaging_extension_submit_action_dispatch( - turn_context, action - ) - - async def on_teams_messaging_extension_submit_action( - self, turn_context: TurnContext, action: MessagingExtensionAction - ): - self.record.append("on_teams_messaging_extension_submit_action") - return await super().on_teams_messaging_extension_submit_action( - turn_context, action - ) - - async def on_teams_messaging_extension_bot_message_preview_edit( - self, turn_context: TurnContext, action: MessagingExtensionAction - ): - self.record.append("on_teams_messaging_extension_bot_message_preview_edit") - return await super().on_teams_messaging_extension_bot_message_preview_edit( - turn_context, action - ) - - async def on_teams_messaging_extension_bot_message_preview_send( - self, turn_context: TurnContext, action: MessagingExtensionAction - ): - self.record.append("on_teams_messaging_extension_bot_message_preview_send") - return await super().on_teams_messaging_extension_bot_message_preview_send( - turn_context, action - ) - - async def on_teams_messaging_extension_fetch_task( - self, turn_context: TurnContext, action: MessagingExtensionAction - ): - self.record.append("on_teams_messaging_extension_fetch_task") - return await super().on_teams_messaging_extension_fetch_task( - turn_context, action - ) - - async def on_teams_messaging_extension_configuration_query_settings_url( - self, turn_context: TurnContext, query: MessagingExtensionQuery - ): - self.record.append( - "on_teams_messaging_extension_configuration_query_settings_url" - ) - return await super().on_teams_messaging_extension_configuration_query_settings_url( - turn_context, query - ) - - async def on_teams_messaging_extension_configuration_setting( - self, turn_context: TurnContext, settings - ): - self.record.append("on_teams_messaging_extension_configuration_setting") - return await super().on_teams_messaging_extension_configuration_setting( - turn_context, settings - ) - - async def on_teams_messaging_extension_card_button_clicked( - self, turn_context: TurnContext, card_data - ): - self.record.append("on_teams_messaging_extension_card_button_clicked") - return await super().on_teams_messaging_extension_card_button_clicked( - turn_context, card_data - ) - - async def on_teams_task_module_fetch( - self, turn_context: TurnContext, task_module_request - ): - self.record.append("on_teams_task_module_fetch") - return await super().on_teams_task_module_fetch( - turn_context, task_module_request - ) - - async def on_teams_task_module_submit( # pylint: disable=unused-argument - self, turn_context: TurnContext, task_module_request: TaskModuleRequest - ): - self.record.append("on_teams_task_module_submit") - return await super().on_teams_task_module_submit( - turn_context, task_module_request - ) - - -class NotImplementedAdapter(BotAdapter): - async def delete_activity( - self, context: TurnContext, reference: ConversationReference - ): - raise NotImplementedError() - - async def send_activities( - self, context: TurnContext, activities: List[Activity] - ) -> List[ResourceResponse]: - raise NotImplementedError() - - async def update_activity(self, context: TurnContext, activity: Activity): - raise NotImplementedError() - - -class TestTeamsActivityHandler(aiounittest.AsyncTestCase): - async def test_on_teams_channel_created_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.conversation_update, - channel_data={ - "eventType": "channelCreated", - "channel": {"id": "asdfqwerty", "name": "new_channel"}, - }, - channel_id=Channels.ms_teams, - ) - - turn_context = TurnContext(NotImplementedAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_conversation_update_activity" - assert bot.record[1] == "on_teams_channel_created" - - async def test_on_teams_channel_renamed_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.conversation_update, - channel_data={ - "eventType": "channelRenamed", - "channel": {"id": "asdfqwerty", "name": "new_channel"}, - }, - channel_id=Channels.ms_teams, - ) - - turn_context = TurnContext(NotImplementedAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_conversation_update_activity" - assert bot.record[1] == "on_teams_channel_renamed" - - async def test_on_teams_channel_deleted_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.conversation_update, - channel_data={ - "eventType": "channelDeleted", - "channel": {"id": "asdfqwerty", "name": "new_channel"}, - }, - channel_id=Channels.ms_teams, - ) - - turn_context = TurnContext(NotImplementedAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_conversation_update_activity" - assert bot.record[1] == "on_teams_channel_deleted" - - async def test_on_teams_team_renamed_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.conversation_update, - channel_data={ - "eventType": "teamRenamed", - "team": {"id": "team_id_1", "name": "new_team_name"}, - }, - channel_id=Channels.ms_teams, - ) - - turn_context = TurnContext(NotImplementedAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_conversation_update_activity" - assert bot.record[1] == "on_teams_team_renamed_activity" - - async def test_on_teams_members_removed_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.conversation_update, - channel_data={"eventType": "teamMemberRemoved"}, - members_removed=[ - ChannelAccount( - id="123", - name="test_user", - aad_object_id="asdfqwerty", - role="tester", - ) - ], - channel_id=Channels.ms_teams, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_conversation_update_activity" - assert bot.record[1] == "on_teams_members_removed" - - async def test_on_signin_verify_state(self): - # arrange - activity = Activity(type=ActivityTypes.invoke, name="signin/verifyState") - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_signin_verify_state" - - async def test_on_file_consent_accept_activity(self): - # arrange - activity = Activity( - type=ActivityTypes.invoke, - name="fileConsent/invoke", - value={"action": "accept"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_file_consent" - assert bot.record[2] == "on_teams_file_consent_accept" - - async def test_on_file_consent_decline_activity(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="fileConsent/invoke", - value={"action": "decline"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_file_consent" - assert bot.record[2] == "on_teams_file_consent_decline" - - async def test_on_file_consent_bad_action_activity(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="fileConsent/invoke", - value={"action": "bad_action"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_file_consent" - - async def test_on_teams_o365_connector_card_action(self): - # arrange - activity = Activity( - type=ActivityTypes.invoke, - name="actionableMessage/executeAction", - value={"body": "body_here", "actionId": "action_id_here"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_o365_connector_card_action" - - async def test_on_app_based_link_query(self): - # arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/query", - value={"url": "http://www.test.com"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_query" - - async def test_on_teams_messaging_extension_bot_message_preview_edit_activity(self): - # Arrange - - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/submitAction", - value={ - "data": {"key": "value"}, - "context": {"theme": "dark"}, - "commandId": "test_command", - "commandContext": "command_context_test", - "botMessagePreviewAction": "edit", - "botActivityPreview": [{"id": "activity123"}], - "messagePayload": {"id": "payloadid"}, - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" - assert bot.record[2] == "on_teams_messaging_extension_bot_message_preview_edit" - - async def test_on_teams_messaging_extension_bot_message_send_activity(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/submitAction", - value={ - "data": {"key": "value"}, - "context": {"theme": "dark"}, - "commandId": "test_command", - "commandContext": "command_context_test", - "botMessagePreviewAction": "send", - "botActivityPreview": [{"id": "123"}], - "messagePayload": {"id": "abc"}, - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" - assert bot.record[2] == "on_teams_messaging_extension_bot_message_preview_send" - - async def test_on_teams_messaging_extension_bot_message_send_activity_with_none( - self, - ): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/submitAction", - value={ - "data": {"key": "value"}, - "context": {"theme": "dark"}, - "commandId": "test_command", - "commandContext": "command_context_test", - "botMessagePreviewAction": None, - "botActivityPreview": [{"id": "test123"}], - "messagePayload": {"id": "payloadid123"}, - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" - assert bot.record[2] == "on_teams_messaging_extension_submit_action" - - async def test_on_teams_messaging_extension_bot_message_send_activity_with_empty_string( - self, - ): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/submitAction", - value={ - "data": {"key": "value"}, - "context": {"theme": "dark"}, - "commandId": "test_command", - "commandContext": "command_context_test", - "botMessagePreviewAction": "", - "botActivityPreview": [Activity().serialize()], - "messagePayload": MessageActionsPayload().serialize(), - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" - assert bot.record[2] == "on_teams_messaging_extension_submit_action" - - async def test_on_teams_messaging_extension_fetch_task(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/fetchTask", - value={ - "data": {"key": "value"}, - "context": {"theme": "dark"}, - "commandId": "test_command", - "commandContext": "command_context_test", - "botMessagePreviewAction": "message_action", - "botActivityPreview": [{"id": "123"}], - "messagePayload": {"id": "abc123"}, - }, - ) - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_fetch_task" - - async def test_on_teams_messaging_extension_configuration_query_settings_url(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/querySettingUrl", - value={ - "commandId": "test_command", - "parameters": [], - "messagingExtensionQueryOptions": {"skip": 1, "count": 1}, - "state": "state_string", - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert ( - bot.record[1] - == "on_teams_messaging_extension_configuration_query_settings_url" - ) - - async def test_on_teams_messaging_extension_configuration_setting(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/setting", - value={"key": "value"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_configuration_setting" - - async def test_on_teams_messaging_extension_card_button_clicked(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="composeExtension/onCardButtonClicked", - value={"key": "value"}, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_messaging_extension_card_button_clicked" - - async def test_on_teams_task_module_fetch(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="task/fetch", - value={ - "data": {"key": "value"}, - "context": TaskModuleRequestContext().serialize(), - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_task_module_fetch" - - async def test_on_teams_task_module_submit(self): - # Arrange - activity = Activity( - type=ActivityTypes.invoke, - name="task/submit", - value={ - "data": {"key": "value"}, - "context": TaskModuleRequestContext().serialize(), - }, - ) - - turn_context = TurnContext(SimpleAdapter(), activity) - - # Act - bot = TestingTeamsActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 2 - assert bot.record[0] == "on_invoke_activity" - assert bot.record[1] == "on_teams_task_module_submit" +from typing import List + +import aiounittest +from botbuilder.core import BotAdapter, TurnContext +from botbuilder.core.teams import TeamsActivityHandler +from botbuilder.schema import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationReference, + ResourceResponse, +) +from botbuilder.schema.teams import ( + AppBasedLinkQuery, + ChannelInfo, + FileConsentCardResponse, + MessageActionsPayload, + MessagingExtensionAction, + MessagingExtensionQuery, + O365ConnectorCardActionQuery, + TaskModuleRequest, + TaskModuleRequestContext, + TeamInfo, + TeamsChannelAccount, +) +from botframework.connector import Channels +from simple_adapter import SimpleAdapter + + +class TestingTeamsActivityHandler(TeamsActivityHandler): + __test__ = False + + def __init__(self): + self.record: List[str] = [] + + async def on_conversation_update_activity(self, turn_context: TurnContext): + self.record.append("on_conversation_update_activity") + return await super().on_conversation_update_activity(turn_context) + + async def on_teams_members_removed( + self, teams_members_removed: [TeamsChannelAccount], turn_context: TurnContext + ): + self.record.append("on_teams_members_removed") + return await super().on_teams_members_removed( + teams_members_removed, turn_context + ) + + async def on_message_activity(self, turn_context: TurnContext): + self.record.append("on_message_activity") + return await super().on_message_activity(turn_context) + + async def on_token_response_event(self, turn_context: TurnContext): + self.record.append("on_token_response_event") + return await super().on_token_response_event(turn_context) + + async def on_event(self, turn_context: TurnContext): + self.record.append("on_event") + return await super().on_event(turn_context) + + async def on_unrecognized_activity_type(self, turn_context: TurnContext): + self.record.append("on_unrecognized_activity_type") + return await super().on_unrecognized_activity_type(turn_context) + + async def on_teams_channel_created( + self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext + ): + self.record.append("on_teams_channel_created") + return await super().on_teams_channel_created( + channel_info, team_info, turn_context + ) + + async def on_teams_channel_renamed( + self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext + ): + self.record.append("on_teams_channel_renamed") + return await super().on_teams_channel_renamed( + channel_info, team_info, turn_context + ) + + async def on_teams_channel_deleted( + self, channel_info: ChannelInfo, team_info: TeamInfo, turn_context: TurnContext + ): + self.record.append("on_teams_channel_deleted") + return await super().on_teams_channel_renamed( + channel_info, team_info, turn_context + ) + + async def on_teams_team_renamed_activity( + self, team_info: TeamInfo, turn_context: TurnContext + ): + self.record.append("on_teams_team_renamed_activity") + return await super().on_teams_team_renamed_activity(team_info, turn_context) + + async def on_invoke_activity(self, turn_context: TurnContext): + self.record.append("on_invoke_activity") + return await super().on_invoke_activity(turn_context) + + async def on_teams_signin_verify_state(self, turn_context: TurnContext): + self.record.append("on_teams_signin_verify_state") + return await super().on_teams_signin_verify_state(turn_context) + + async def on_teams_file_consent( + self, + turn_context: TurnContext, + file_consent_card_response: FileConsentCardResponse, + ): + self.record.append("on_teams_file_consent") + return await super().on_teams_file_consent( + turn_context, file_consent_card_response + ) + + async def on_teams_file_consent_accept( + self, + turn_context: TurnContext, + file_consent_card_response: FileConsentCardResponse, + ): + self.record.append("on_teams_file_consent_accept") + return await super().on_teams_file_consent_accept( + turn_context, file_consent_card_response + ) + + async def on_teams_file_consent_decline( + self, + turn_context: TurnContext, + file_consent_card_response: FileConsentCardResponse, + ): + self.record.append("on_teams_file_consent_decline") + return await super().on_teams_file_consent_decline( + turn_context, file_consent_card_response + ) + + async def on_teams_o365_connector_card_action( + self, turn_context: TurnContext, query: O365ConnectorCardActionQuery + ): + self.record.append("on_teams_o365_connector_card_action") + return await super().on_teams_o365_connector_card_action(turn_context, query) + + async def on_teams_app_based_link_query( + self, turn_context: TurnContext, query: AppBasedLinkQuery + ): + self.record.append("on_teams_app_based_link_query") + return await super().on_teams_app_based_link_query(turn_context, query) + + async def on_teams_messaging_extension_query( + self, turn_context: TurnContext, query: MessagingExtensionQuery + ): + self.record.append("on_teams_messaging_extension_query") + return await super().on_teams_messaging_extension_query(turn_context, query) + + async def on_teams_messaging_extension_submit_action_dispatch( + self, turn_context: TurnContext, action: MessagingExtensionAction + ): + self.record.append("on_teams_messaging_extension_submit_action_dispatch") + return await super().on_teams_messaging_extension_submit_action_dispatch( + turn_context, action + ) + + async def on_teams_messaging_extension_submit_action( + self, turn_context: TurnContext, action: MessagingExtensionAction + ): + self.record.append("on_teams_messaging_extension_submit_action") + return await super().on_teams_messaging_extension_submit_action( + turn_context, action + ) + + async def on_teams_messaging_extension_bot_message_preview_edit( + self, turn_context: TurnContext, action: MessagingExtensionAction + ): + self.record.append("on_teams_messaging_extension_bot_message_preview_edit") + return await super().on_teams_messaging_extension_bot_message_preview_edit( + turn_context, action + ) + + async def on_teams_messaging_extension_bot_message_preview_send( + self, turn_context: TurnContext, action: MessagingExtensionAction + ): + self.record.append("on_teams_messaging_extension_bot_message_preview_send") + return await super().on_teams_messaging_extension_bot_message_preview_send( + turn_context, action + ) + + async def on_teams_messaging_extension_fetch_task( + self, turn_context: TurnContext, action: MessagingExtensionAction + ): + self.record.append("on_teams_messaging_extension_fetch_task") + return await super().on_teams_messaging_extension_fetch_task( + turn_context, action + ) + + async def on_teams_messaging_extension_configuration_query_settings_url( + self, turn_context: TurnContext, query: MessagingExtensionQuery + ): + self.record.append( + "on_teams_messaging_extension_configuration_query_settings_url" + ) + return await super().on_teams_messaging_extension_configuration_query_settings_url( + turn_context, query + ) + + async def on_teams_messaging_extension_configuration_setting( + self, turn_context: TurnContext, settings + ): + self.record.append("on_teams_messaging_extension_configuration_setting") + return await super().on_teams_messaging_extension_configuration_setting( + turn_context, settings + ) + + async def on_teams_messaging_extension_card_button_clicked( + self, turn_context: TurnContext, card_data + ): + self.record.append("on_teams_messaging_extension_card_button_clicked") + return await super().on_teams_messaging_extension_card_button_clicked( + turn_context, card_data + ) + + async def on_teams_task_module_fetch( + self, turn_context: TurnContext, task_module_request + ): + self.record.append("on_teams_task_module_fetch") + return await super().on_teams_task_module_fetch( + turn_context, task_module_request + ) + + async def on_teams_task_module_submit( # pylint: disable=unused-argument + self, turn_context: TurnContext, task_module_request: TaskModuleRequest + ): + self.record.append("on_teams_task_module_submit") + return await super().on_teams_task_module_submit( + turn_context, task_module_request + ) + + +class NotImplementedAdapter(BotAdapter): + async def delete_activity( + self, context: TurnContext, reference: ConversationReference + ): + raise NotImplementedError() + + async def send_activities( + self, context: TurnContext, activities: List[Activity] + ) -> List[ResourceResponse]: + raise NotImplementedError() + + async def update_activity(self, context: TurnContext, activity: Activity): + raise NotImplementedError() + + +class TestTeamsActivityHandler(aiounittest.AsyncTestCase): + async def test_on_teams_channel_created_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.conversation_update, + channel_data={ + "eventType": "channelCreated", + "channel": {"id": "asdfqwerty", "name": "new_channel"}, + }, + channel_id=Channels.ms_teams, + ) + + turn_context = TurnContext(NotImplementedAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_conversation_update_activity" + assert bot.record[1] == "on_teams_channel_created" + + async def test_on_teams_channel_renamed_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.conversation_update, + channel_data={ + "eventType": "channelRenamed", + "channel": {"id": "asdfqwerty", "name": "new_channel"}, + }, + channel_id=Channels.ms_teams, + ) + + turn_context = TurnContext(NotImplementedAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_conversation_update_activity" + assert bot.record[1] == "on_teams_channel_renamed" + + async def test_on_teams_channel_deleted_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.conversation_update, + channel_data={ + "eventType": "channelDeleted", + "channel": {"id": "asdfqwerty", "name": "new_channel"}, + }, + channel_id=Channels.ms_teams, + ) + + turn_context = TurnContext(NotImplementedAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_conversation_update_activity" + assert bot.record[1] == "on_teams_channel_deleted" + + async def test_on_teams_team_renamed_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.conversation_update, + channel_data={ + "eventType": "teamRenamed", + "team": {"id": "team_id_1", "name": "new_team_name"}, + }, + channel_id=Channels.ms_teams, + ) + + turn_context = TurnContext(NotImplementedAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_conversation_update_activity" + assert bot.record[1] == "on_teams_team_renamed_activity" + + async def test_on_teams_members_removed_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.conversation_update, + channel_data={"eventType": "teamMemberRemoved"}, + members_removed=[ + ChannelAccount( + id="123", + name="test_user", + aad_object_id="asdfqwerty", + role="tester", + ) + ], + channel_id=Channels.ms_teams, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_conversation_update_activity" + assert bot.record[1] == "on_teams_members_removed" + + async def test_on_signin_verify_state(self): + # arrange + activity = Activity(type=ActivityTypes.invoke, name="signin/verifyState") + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_signin_verify_state" + + async def test_on_file_consent_accept_activity(self): + # arrange + activity = Activity( + type=ActivityTypes.invoke, + name="fileConsent/invoke", + value={"action": "accept"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_file_consent" + assert bot.record[2] == "on_teams_file_consent_accept" + + async def test_on_file_consent_decline_activity(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="fileConsent/invoke", + value={"action": "decline"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_file_consent" + assert bot.record[2] == "on_teams_file_consent_decline" + + async def test_on_file_consent_bad_action_activity(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="fileConsent/invoke", + value={"action": "bad_action"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_file_consent" + + async def test_on_teams_o365_connector_card_action(self): + # arrange + activity = Activity( + type=ActivityTypes.invoke, + name="actionableMessage/executeAction", + value={"body": "body_here", "actionId": "action_id_here"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_o365_connector_card_action" + + async def test_on_app_based_link_query(self): + # arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/query", + value={"url": "http://www.test.com"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_query" + + async def test_on_teams_messaging_extension_bot_message_preview_edit_activity(self): + # Arrange + + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/submitAction", + value={ + "data": {"key": "value"}, + "context": {"theme": "dark"}, + "commandId": "test_command", + "commandContext": "command_context_test", + "botMessagePreviewAction": "edit", + "botActivityPreview": [{"id": "activity123"}], + "messagePayload": {"id": "payloadid"}, + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" + assert bot.record[2] == "on_teams_messaging_extension_bot_message_preview_edit" + + async def test_on_teams_messaging_extension_bot_message_send_activity(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/submitAction", + value={ + "data": {"key": "value"}, + "context": {"theme": "dark"}, + "commandId": "test_command", + "commandContext": "command_context_test", + "botMessagePreviewAction": "send", + "botActivityPreview": [{"id": "123"}], + "messagePayload": {"id": "abc"}, + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" + assert bot.record[2] == "on_teams_messaging_extension_bot_message_preview_send" + + async def test_on_teams_messaging_extension_bot_message_send_activity_with_none( + self, + ): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/submitAction", + value={ + "data": {"key": "value"}, + "context": {"theme": "dark"}, + "commandId": "test_command", + "commandContext": "command_context_test", + "botMessagePreviewAction": None, + "botActivityPreview": [{"id": "test123"}], + "messagePayload": {"id": "payloadid123"}, + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" + assert bot.record[2] == "on_teams_messaging_extension_submit_action" + + async def test_on_teams_messaging_extension_bot_message_send_activity_with_empty_string( + self, + ): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/submitAction", + value={ + "data": {"key": "value"}, + "context": {"theme": "dark"}, + "commandId": "test_command", + "commandContext": "command_context_test", + "botMessagePreviewAction": "", + "botActivityPreview": [Activity().serialize()], + "messagePayload": MessageActionsPayload().serialize(), + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_submit_action_dispatch" + assert bot.record[2] == "on_teams_messaging_extension_submit_action" + + async def test_on_teams_messaging_extension_fetch_task(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/fetchTask", + value={ + "data": {"key": "value"}, + "context": {"theme": "dark"}, + "commandId": "test_command", + "commandContext": "command_context_test", + "botMessagePreviewAction": "message_action", + "botActivityPreview": [{"id": "123"}], + "messagePayload": {"id": "abc123"}, + }, + ) + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_fetch_task" + + async def test_on_teams_messaging_extension_configuration_query_settings_url(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/querySettingUrl", + value={ + "commandId": "test_command", + "parameters": [], + "messagingExtensionQueryOptions": {"skip": 1, "count": 1}, + "state": "state_string", + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert ( + bot.record[1] + == "on_teams_messaging_extension_configuration_query_settings_url" + ) + + async def test_on_teams_messaging_extension_configuration_setting(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/setting", + value={"key": "value"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_configuration_setting" + + async def test_on_teams_messaging_extension_card_button_clicked(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="composeExtension/onCardButtonClicked", + value={"key": "value"}, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_messaging_extension_card_button_clicked" + + async def test_on_teams_task_module_fetch(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="task/fetch", + value={ + "data": {"key": "value"}, + "context": TaskModuleRequestContext().serialize(), + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_task_module_fetch" + + async def test_on_teams_task_module_submit(self): + # Arrange + activity = Activity( + type=ActivityTypes.invoke, + name="task/submit", + value={ + "data": {"key": "value"}, + "context": TaskModuleRequestContext().serialize(), + }, + ) + + turn_context = TurnContext(SimpleAdapter(), activity) + + # Act + bot = TestingTeamsActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 2 + assert bot.record[0] == "on_invoke_activity" + assert bot.record[1] == "on_teams_task_module_submit" diff --git a/libraries/botbuilder-core/tests/test_activity_handler.py b/libraries/botbuilder-core/tests/test_activity_handler.py index 90a49019b..5e6916dd0 100644 --- a/libraries/botbuilder-core/tests/test_activity_handler.py +++ b/libraries/botbuilder-core/tests/test_activity_handler.py @@ -1,101 +1,103 @@ -from typing import List - -import aiounittest -from botbuilder.core import ActivityHandler, BotAdapter, TurnContext -from botbuilder.schema import ( - Activity, - ActivityTypes, - ChannelAccount, - ConversationReference, - MessageReaction, - ResourceResponse, -) - - -class TestingActivityHandler(ActivityHandler): - def __init__(self): - self.record: List[str] = [] - - async def on_message_activity(self, turn_context: TurnContext): - self.record.append("on_message_activity") - return await super().on_message_activity(turn_context) - - async def on_members_added_activity( - self, members_added: ChannelAccount, turn_context: TurnContext - ): - self.record.append("on_members_added_activity") - return await super().on_members_added_activity(members_added, turn_context) - - async def on_members_removed_activity( - self, members_removed: ChannelAccount, turn_context: TurnContext - ): - self.record.append("on_members_removed_activity") - return await super().on_members_removed_activity(members_removed, turn_context) - - async def on_message_reaction_activity(self, turn_context: TurnContext): - self.record.append("on_message_reaction_activity") - return await super().on_message_reaction_activity(turn_context) - - async def on_reactions_added( - self, message_reactions: List[MessageReaction], turn_context: TurnContext - ): - self.record.append("on_reactions_added") - return await super().on_reactions_added(message_reactions, turn_context) - - async def on_reactions_removed( - self, message_reactions: List[MessageReaction], turn_context: TurnContext - ): - self.record.append("on_reactions_removed") - return await super().on_reactions_removed(message_reactions, turn_context) - - async def on_token_response_event(self, turn_context: TurnContext): - self.record.append("on_token_response_event") - return await super().on_token_response_event(turn_context) - - async def on_event(self, turn_context: TurnContext): - self.record.append("on_event") - return await super().on_event(turn_context) - - async def on_unrecognized_activity_type(self, turn_context: TurnContext): - self.record.append("on_unrecognized_activity_type") - return await super().on_unrecognized_activity_type(turn_context) - - -class NotImplementedAdapter(BotAdapter): - async def delete_activity( - self, context: TurnContext, reference: ConversationReference - ): - raise NotImplementedError() - - async def send_activities( - self, context: TurnContext, activities: List[Activity] - ) -> List[ResourceResponse]: - raise NotImplementedError() - - async def update_activity(self, context: TurnContext, activity: Activity): - raise NotImplementedError() - - -class TestActivityHandler(aiounittest.AsyncTestCase): - async def test_message_reaction(self): - # Note the code supports multiple adds and removes in the same activity though - # a channel may decide to send separate activities for each. For example, Teams - # sends separate activities each with a single add and a single remove. - - # Arrange - activity = Activity( - type=ActivityTypes.message_reaction, - reactions_added=[MessageReaction(type="sad")], - reactions_removed=[MessageReaction(type="angry")], - ) - turn_context = TurnContext(NotImplementedAdapter(), activity) - - # Act - bot = TestingActivityHandler() - await bot.on_turn(turn_context) - - # Assert - assert len(bot.record) == 3 - assert bot.record[0] == "on_message_reaction_activity" - assert bot.record[1] == "on_reactions_added" - assert bot.record[2] == "on_reactions_removed" +from typing import List + +import aiounittest +from botbuilder.core import ActivityHandler, BotAdapter, TurnContext +from botbuilder.schema import ( + Activity, + ActivityTypes, + ChannelAccount, + ConversationReference, + MessageReaction, + ResourceResponse, +) + + +class TestingActivityHandler(ActivityHandler): + __test__ = False + + def __init__(self): + self.record: List[str] = [] + + async def on_message_activity(self, turn_context: TurnContext): + self.record.append("on_message_activity") + return await super().on_message_activity(turn_context) + + async def on_members_added_activity( + self, members_added: ChannelAccount, turn_context: TurnContext + ): + self.record.append("on_members_added_activity") + return await super().on_members_added_activity(members_added, turn_context) + + async def on_members_removed_activity( + self, members_removed: ChannelAccount, turn_context: TurnContext + ): + self.record.append("on_members_removed_activity") + return await super().on_members_removed_activity(members_removed, turn_context) + + async def on_message_reaction_activity(self, turn_context: TurnContext): + self.record.append("on_message_reaction_activity") + return await super().on_message_reaction_activity(turn_context) + + async def on_reactions_added( + self, message_reactions: List[MessageReaction], turn_context: TurnContext + ): + self.record.append("on_reactions_added") + return await super().on_reactions_added(message_reactions, turn_context) + + async def on_reactions_removed( + self, message_reactions: List[MessageReaction], turn_context: TurnContext + ): + self.record.append("on_reactions_removed") + return await super().on_reactions_removed(message_reactions, turn_context) + + async def on_token_response_event(self, turn_context: TurnContext): + self.record.append("on_token_response_event") + return await super().on_token_response_event(turn_context) + + async def on_event(self, turn_context: TurnContext): + self.record.append("on_event") + return await super().on_event(turn_context) + + async def on_unrecognized_activity_type(self, turn_context: TurnContext): + self.record.append("on_unrecognized_activity_type") + return await super().on_unrecognized_activity_type(turn_context) + + +class NotImplementedAdapter(BotAdapter): + async def delete_activity( + self, context: TurnContext, reference: ConversationReference + ): + raise NotImplementedError() + + async def send_activities( + self, context: TurnContext, activities: List[Activity] + ) -> List[ResourceResponse]: + raise NotImplementedError() + + async def update_activity(self, context: TurnContext, activity: Activity): + raise NotImplementedError() + + +class TestActivityHandler(aiounittest.AsyncTestCase): + async def test_message_reaction(self): + # Note the code supports multiple adds and removes in the same activity though + # a channel may decide to send separate activities for each. For example, Teams + # sends separate activities each with a single add and a single remove. + + # Arrange + activity = Activity( + type=ActivityTypes.message_reaction, + reactions_added=[MessageReaction(type="sad")], + reactions_removed=[MessageReaction(type="angry")], + ) + turn_context = TurnContext(NotImplementedAdapter(), activity) + + # Act + bot = TestingActivityHandler() + await bot.on_turn(turn_context) + + # Assert + assert len(bot.record) == 3 + assert bot.record[0] == "on_message_reaction_activity" + assert bot.record[1] == "on_reactions_added" + assert bot.record[2] == "on_reactions_removed" diff --git a/libraries/botbuilder-core/tests/test_bot_adapter.py b/libraries/botbuilder-core/tests/test_bot_adapter.py index 9edd36c50..5f524dca2 100644 --- a/libraries/botbuilder-core/tests/test_bot_adapter.py +++ b/libraries/botbuilder-core/tests/test_bot_adapter.py @@ -1,86 +1,86 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import uuid -from typing import List -import aiounittest - -from botbuilder.core import TurnContext -from botbuilder.core.adapters import TestAdapter -from botbuilder.schema import ( - Activity, - ConversationAccount, - ConversationReference, - ChannelAccount, -) - -from simple_adapter import SimpleAdapter -from call_counting_middleware import CallCountingMiddleware -from test_message import TestMessage - - -class TestBotAdapter(aiounittest.AsyncTestCase): - def test_adapter_single_use(self): - adapter = SimpleAdapter() - adapter.use(CallCountingMiddleware()) - - def test_adapter_use_chaining(self): - adapter = SimpleAdapter() - adapter.use(CallCountingMiddleware()).use(CallCountingMiddleware()) - - async def test_pass_resource_responses_through(self): - def validate_responses( # pylint: disable=unused-argument - activities: List[Activity], - ): - pass # no need to do anything. - - adapter = SimpleAdapter(call_on_send=validate_responses) - context = TurnContext(adapter, Activity()) - - activity_id = str(uuid.uuid1()) - activity = TestMessage.message(activity_id) - - resource_response = await context.send_activity(activity) - self.assertTrue( - resource_response.id != activity_id, "Incorrect response Id returned" - ) - - async def test_continue_conversation_direct_msg(self): - callback_invoked = False - adapter = TestAdapter() - reference = ConversationReference( - activity_id="activityId", - bot=ChannelAccount(id="channelId", name="testChannelAccount", role="bot"), - channel_id="testChannel", - service_url="testUrl", - conversation=ConversationAccount( - conversation_type="", - id="testConversationId", - is_group=False, - name="testConversationName", - role="user", - ), - user=ChannelAccount(id="channelId", name="testChannelAccount", role="bot"), - ) - - async def continue_callback(turn_context): # pylint: disable=unused-argument - nonlocal callback_invoked - callback_invoked = True - - await adapter.continue_conversation(reference, continue_callback, "MyBot") - self.assertTrue(callback_invoked) - - async def test_turn_error(self): - async def on_error(turn_context: TurnContext, err: Exception): - nonlocal self - self.assertIsNotNone(turn_context, "turn_context not found.") - self.assertIsNotNone(err, "error not found.") - self.assertEqual(err.__class__, Exception, "unexpected error thrown.") - - adapter = SimpleAdapter() - adapter.on_turn_error = on_error - - def handler(context: TurnContext): # pylint: disable=unused-argument - raise Exception - - await adapter.process_request(TestMessage.message(), handler) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import uuid +from typing import List +import aiounittest + +from botbuilder.core import TurnContext +from botbuilder.core.adapters import TestAdapter +from botbuilder.schema import ( + Activity, + ConversationAccount, + ConversationReference, + ChannelAccount, +) + +from simple_adapter import SimpleAdapter +from call_counting_middleware import CallCountingMiddleware +from test_message import TestMessage + + +class TestBotAdapter(aiounittest.AsyncTestCase): + def test_adapter_single_use(self): + adapter = SimpleAdapter() + adapter.use(CallCountingMiddleware()) + + def test_adapter_use_chaining(self): + adapter = SimpleAdapter() + adapter.use(CallCountingMiddleware()).use(CallCountingMiddleware()) + + async def test_pass_resource_responses_through(self): + def validate_responses( # pylint: disable=unused-argument + activities: List[Activity], + ): + pass # no need to do anything. + + adapter = SimpleAdapter(call_on_send=validate_responses) + context = TurnContext(adapter, Activity()) + + activity_id = str(uuid.uuid1()) + activity = TestMessage.message(activity_id) + + resource_response = await context.send_activity(activity) + self.assertTrue( + resource_response.id != activity_id, "Incorrect response Id returned" + ) + + async def test_continue_conversation_direct_msg(self): + callback_invoked = False + adapter = TestAdapter() + reference = ConversationReference( + activity_id="activityId", + bot=ChannelAccount(id="channelId", name="testChannelAccount", role="bot"), + channel_id="testChannel", + service_url="testUrl", + conversation=ConversationAccount( + conversation_type="", + id="testConversationId", + is_group=False, + name="testConversationName", + role="user", + ), + user=ChannelAccount(id="channelId", name="testChannelAccount", role="bot"), + ) + + async def continue_callback(turn_context): # pylint: disable=unused-argument + nonlocal callback_invoked + callback_invoked = True + + await adapter.continue_conversation(reference, continue_callback, "MyBot") + self.assertTrue(callback_invoked) + + async def test_turn_error(self): + async def on_error(turn_context: TurnContext, err: Exception): + nonlocal self + self.assertIsNotNone(turn_context, "turn_context not found.") + self.assertIsNotNone(err, "error not found.") + self.assertEqual(err.__class__, Exception, "unexpected error thrown.") + + adapter = SimpleAdapter() + adapter.on_turn_error = on_error + + def handler(context: TurnContext): # pylint: disable=unused-argument + raise Exception + + await adapter.process_request(TestMessage.message(), handler) diff --git a/libraries/botbuilder-core/tests/test_bot_state.py b/libraries/botbuilder-core/tests/test_bot_state.py index 13dec6d53..2c0eb815e 100644 --- a/libraries/botbuilder-core/tests/test_bot_state.py +++ b/libraries/botbuilder-core/tests/test_bot_state.py @@ -1,483 +1,485 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from unittest.mock import MagicMock -import aiounittest - -from botbuilder.core import ( - BotState, - ConversationState, - MemoryStorage, - Storage, - StoreItem, - TurnContext, - UserState, -) -from botbuilder.core.adapters import TestAdapter -from botbuilder.schema import Activity, ConversationAccount - -from test_utilities import TestUtilities - -RECEIVED_MESSAGE = Activity(type="message", text="received") -STORAGE_KEY = "stateKey" - - -def cached_state(context, state_key): - cached = context.services.get(state_key) - return cached["state"] if cached is not None else None - - -def key_factory(context): - assert context is not None - return STORAGE_KEY - - -class BotStateForTest(BotState): - def __init__(self, storage: Storage): - super().__init__(storage, f"BotState:BotState") - - def get_storage_key(self, turn_context: TurnContext) -> str: - return f"botstate/{turn_context.activity.channel_id}/{turn_context.activity.conversation.id}/BotState" - - -class CustomState(StoreItem): - def __init__(self, custom_string: str = None, e_tag: str = "*"): - super().__init__(custom_string=custom_string, e_tag=e_tag) - - -class TestPocoState: - def __init__(self, value=None): - self.value = value - - -class TestBotState(aiounittest.AsyncTestCase): - storage = MemoryStorage() - adapter = TestAdapter() - context = TurnContext(adapter, RECEIVED_MESSAGE) - middleware = BotState(storage, key_factory) - - def test_state_empty_name(self): - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - - # Act - with self.assertRaises(TypeError) as _: - user_state.create_property("") - - def test_state_none_name(self): - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - - # Act - with self.assertRaises(TypeError) as _: - user_state.create_property(None) - - async def test_storage_not_called_no_changes(self): - """Verify storage not called when no changes are made""" - # Mock a storage provider, which counts read/writes - dictionary = {} - - async def mock_write_result(self): # pylint: disable=unused-argument - return - - async def mock_read_result(self): # pylint: disable=unused-argument - return {} - - mock_storage = MemoryStorage(dictionary) - mock_storage.write = MagicMock(side_effect=mock_write_result) - mock_storage.read = MagicMock(side_effect=mock_read_result) - - # Arrange - user_state = UserState(mock_storage) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property_a") - self.assertEqual(mock_storage.write.call_count, 0) - await user_state.save_changes(context) - await property_a.set(context, "hello") - self.assertEqual(mock_storage.read.call_count, 1) # Initial save bumps count - self.assertEqual(mock_storage.write.call_count, 0) # Initial save bumps count - await property_a.set(context, "there") - self.assertEqual( - mock_storage.write.call_count, 0 - ) # Set on property should not bump - await user_state.save_changes(context) - self.assertEqual(mock_storage.write.call_count, 1) # Explicit save should bump - value_a = await property_a.get(context) - self.assertEqual("there", value_a) - self.assertEqual(mock_storage.write.call_count, 1) # Gets should not bump - await user_state.save_changes(context) - self.assertEqual(mock_storage.write.call_count, 1) - await property_a.delete(context) # Delete alone no bump - self.assertEqual(mock_storage.write.call_count, 1) - await user_state.save_changes(context) # Save when dirty should bump - self.assertEqual(mock_storage.write.call_count, 2) - self.assertEqual(mock_storage.read.call_count, 1) - await user_state.save_changes(context) # Save not dirty should not bump - self.assertEqual(mock_storage.write.call_count, 2) - self.assertEqual(mock_storage.read.call_count, 1) - - async def test_state_set_no_load(self): - """Should be able to set a property with no Load""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property_a") - await property_a.set(context, "hello") - - async def test_state_multiple_loads(self): - """Should be able to load multiple times""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - user_state.create_property("property_a") - await user_state.load(context) - await user_state.load(context) - - async def test_state_get_no_load_with_default(self): - """Should be able to get a property with no Load and default""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property_a") - value_a = await property_a.get(context, lambda: "Default!") - self.assertEqual("Default!", value_a) - - async def test_state_get_no_load_no_default(self): - """Cannot get a string with no default set""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property_a") - value_a = await property_a.get(context) - - # Assert - self.assertIsNone(value_a) - - async def test_state_poco_no_default(self): - """Cannot get a POCO with no default set""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - test_property = user_state.create_property("test") - value = await test_property.get(context) - - # Assert - self.assertIsNone(value) - - async def test_state_bool_no_default(self): - """Cannot get a bool with no default set""" - # Arange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - test_property = user_state.create_property("test") - value = await test_property.get(context) - - # Assert - self.assertFalse(value) - - async def test_state_set_after_save(self): - """Verify setting property after save""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property-a") - property_b = user_state.create_property("property-b") - - await user_state.load(context) - await property_a.set(context, "hello") - await property_b.set(context, "world") - await user_state.save_changes(context) - - await property_a.set(context, "hello2") - - async def test_state_multiple_save(self): - """Verify multiple saves""" - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property-a") - property_b = user_state.create_property("property-b") - - await user_state.load(context) - await property_a.set(context, "hello") - await property_b.set(context, "world") - await user_state.save_changes(context) - - await property_a.set(context, "hello2") - await user_state.save_changes(context) - value_a = await property_a.get(context) - self.assertEqual("hello2", value_a) - - async def test_load_set_save(self): - # Arrange - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - - # Act - property_a = user_state.create_property("property-a") - property_b = user_state.create_property("property-b") - - await user_state.load(context) - await property_a.set(context, "hello") - await property_b.set(context, "world") - await user_state.save_changes(context) - - # Assert - obj = dictionary["EmptyContext/users/empty@empty.context.org"] - self.assertEqual("hello", obj["property-a"]) - self.assertEqual("world", obj["property-b"]) - - async def test_load_set_save_twice(self): - # Arrange - dictionary = {} - context = TestUtilities.create_empty_context() - - # Act - user_state = UserState(MemoryStorage(dictionary)) - - property_a = user_state.create_property("property-a") - property_b = user_state.create_property("property-b") - property_c = user_state.create_property("property-c") - - await user_state.load(context) - await property_a.set(context, "hello") - await property_b.set(context, "world") - await property_c.set(context, "test") - await user_state.save_changes(context) - - # Assert - obj = dictionary["EmptyContext/users/empty@empty.context.org"] - self.assertEqual("hello", obj["property-a"]) - self.assertEqual("world", obj["property-b"]) - - # Act 2 - user_state2 = UserState(MemoryStorage(dictionary)) - - property_a2 = user_state2.create_property("property-a") - property_b2 = user_state2.create_property("property-b") - - await user_state2.load(context) - await property_a2.set(context, "hello-2") - await property_b2.set(context, "world-2") - await user_state2.save_changes(context) - - # Assert 2 - obj2 = dictionary["EmptyContext/users/empty@empty.context.org"] - self.assertEqual("hello-2", obj2["property-a"]) - self.assertEqual("world-2", obj2["property-b"]) - self.assertEqual("test", obj2["property-c"]) - - async def test_load_save_delete(self): - # Arrange - dictionary = {} - context = TestUtilities.create_empty_context() - - # Act - user_state = UserState(MemoryStorage(dictionary)) - - property_a = user_state.create_property("property-a") - property_b = user_state.create_property("property-b") - - await user_state.load(context) - await property_a.set(context, "hello") - await property_b.set(context, "world") - await user_state.save_changes(context) - - # Assert - obj = dictionary["EmptyContext/users/empty@empty.context.org"] - self.assertEqual("hello", obj["property-a"]) - self.assertEqual("world", obj["property-b"]) - - # Act 2 - user_state2 = UserState(MemoryStorage(dictionary)) - - property_a2 = user_state2.create_property("property-a") - property_b2 = user_state2.create_property("property-b") - - await user_state2.load(context) - await property_a2.set(context, "hello-2") - await property_b2.delete(context) - await user_state2.save_changes(context) - - # Assert 2 - obj2 = dictionary["EmptyContext/users/empty@empty.context.org"] - self.assertEqual("hello-2", obj2["property-a"]) - with self.assertRaises(KeyError) as _: - obj2["property-b"] # pylint: disable=pointless-statement - - async def test_state_use_bot_state_directly(self): - async def exec_test(context: TurnContext): - # pylint: disable=unnecessary-lambda - bot_state_manager = BotStateForTest(MemoryStorage()) - test_property = bot_state_manager.create_property("test") - - # read initial state object - await bot_state_manager.load(context) - - custom_state = await test_property.get(context, lambda: CustomState()) - - # this should be a 'CustomState' as nothing is currently stored in storage - assert isinstance(custom_state, CustomState) - - # amend property and write to storage - custom_state.custom_string = "test" - await bot_state_manager.save_changes(context) - - custom_state.custom_string = "asdfsadf" - - # read into context again - await bot_state_manager.load(context, True) - - custom_state = await test_property.get(context) - - # check object read from value has the correct value for custom_string - assert custom_state.custom_string == "test" - - adapter = TestAdapter(exec_test) - await adapter.send("start") - - async def test_user_state_bad_from_throws(self): - dictionary = {} - user_state = UserState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - context.activity.from_property = None - test_property = user_state.create_property("test") - with self.assertRaises(AttributeError): - await test_property.get(context) - - async def test_conversation_state_bad_conversation_throws(self): - dictionary = {} - user_state = ConversationState(MemoryStorage(dictionary)) - context = TestUtilities.create_empty_context() - context.activity.conversation = None - test_property = user_state.create_property("test") - with self.assertRaises(AttributeError): - await test_property.get(context) - - async def test_clear_and_save(self): - # pylint: disable=unnecessary-lambda - turn_context = TestUtilities.create_empty_context() - turn_context.activity.conversation = ConversationAccount(id="1234") - - storage = MemoryStorage({}) - - # Turn 0 - bot_state1 = ConversationState(storage) - ( - await bot_state1.create_property("test-name").get( - turn_context, lambda: TestPocoState() - ) - ).value = "test-value" - await bot_state1.save_changes(turn_context) - - # Turn 1 - bot_state2 = ConversationState(storage) - value1 = ( - await bot_state2.create_property("test-name").get( - turn_context, lambda: TestPocoState(value="default-value") - ) - ).value - - assert value1 == "test-value" - - # Turn 2 - bot_state3 = ConversationState(storage) - await bot_state3.clear_state(turn_context) - await bot_state3.save_changes(turn_context) - - # Turn 3 - bot_state4 = ConversationState(storage) - value2 = ( - await bot_state4.create_property("test-name").get( - turn_context, lambda: TestPocoState(value="default-value") - ) - ).value - - assert value2, "default-value" - - async def test_bot_state_delete(self): - # pylint: disable=unnecessary-lambda - turn_context = TestUtilities.create_empty_context() - turn_context.activity.conversation = ConversationAccount(id="1234") - - storage = MemoryStorage({}) - - # Turn 0 - bot_state1 = ConversationState(storage) - ( - await bot_state1.create_property("test-name").get( - turn_context, lambda: TestPocoState() - ) - ).value = "test-value" - await bot_state1.save_changes(turn_context) - - # Turn 1 - bot_state2 = ConversationState(storage) - value1 = ( - await bot_state2.create_property("test-name").get( - turn_context, lambda: TestPocoState(value="default-value") - ) - ).value - - assert value1 == "test-value" - - # Turn 2 - bot_state3 = ConversationState(storage) - await bot_state3.delete(turn_context) - - # Turn 3 - bot_state4 = ConversationState(storage) - value2 = ( - await bot_state4.create_property("test-name").get( - turn_context, lambda: TestPocoState(value="default-value") - ) - ).value - - assert value2 == "default-value" - - async def test_bot_state_get(self): - # pylint: disable=unnecessary-lambda - turn_context = TestUtilities.create_empty_context() - turn_context.activity.conversation = ConversationAccount(id="1234") - - storage = MemoryStorage({}) - - conversation_state = ConversationState(storage) - ( - await conversation_state.create_property("test-name").get( - turn_context, lambda: TestPocoState() - ) - ).value = "test-value" - - result = conversation_state.get(turn_context) - - assert result["test-name"].value == "test-value" +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from unittest.mock import MagicMock +import aiounittest + +from botbuilder.core import ( + BotState, + ConversationState, + MemoryStorage, + Storage, + StoreItem, + TurnContext, + UserState, +) +from botbuilder.core.adapters import TestAdapter +from botbuilder.schema import Activity, ConversationAccount + +from test_utilities import TestUtilities + +RECEIVED_MESSAGE = Activity(type="message", text="received") +STORAGE_KEY = "stateKey" + + +def cached_state(context, state_key): + cached = context.services.get(state_key) + return cached["state"] if cached is not None else None + + +def key_factory(context): + assert context is not None + return STORAGE_KEY + + +class BotStateForTest(BotState): + def __init__(self, storage: Storage): + super().__init__(storage, f"BotState:BotState") + + def get_storage_key(self, turn_context: TurnContext) -> str: + return f"botstate/{turn_context.activity.channel_id}/{turn_context.activity.conversation.id}/BotState" + + +class CustomState(StoreItem): + def __init__(self, custom_string: str = None, e_tag: str = "*"): + super().__init__(custom_string=custom_string, e_tag=e_tag) + + +class TestPocoState: + __test__ = False + + def __init__(self, value=None): + self.value = value + + +class TestBotState(aiounittest.AsyncTestCase): + storage = MemoryStorage() + adapter = TestAdapter() + context = TurnContext(adapter, RECEIVED_MESSAGE) + middleware = BotState(storage, key_factory) + + def test_state_empty_name(self): + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + + # Act + with self.assertRaises(TypeError) as _: + user_state.create_property("") + + def test_state_none_name(self): + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + + # Act + with self.assertRaises(TypeError) as _: + user_state.create_property(None) + + async def test_storage_not_called_no_changes(self): + """Verify storage not called when no changes are made""" + # Mock a storage provider, which counts read/writes + dictionary = {} + + async def mock_write_result(self): # pylint: disable=unused-argument + return + + async def mock_read_result(self): # pylint: disable=unused-argument + return {} + + mock_storage = MemoryStorage(dictionary) + mock_storage.write = MagicMock(side_effect=mock_write_result) + mock_storage.read = MagicMock(side_effect=mock_read_result) + + # Arrange + user_state = UserState(mock_storage) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property_a") + self.assertEqual(mock_storage.write.call_count, 0) + await user_state.save_changes(context) + await property_a.set(context, "hello") + self.assertEqual(mock_storage.read.call_count, 1) # Initial save bumps count + self.assertEqual(mock_storage.write.call_count, 0) # Initial save bumps count + await property_a.set(context, "there") + self.assertEqual( + mock_storage.write.call_count, 0 + ) # Set on property should not bump + await user_state.save_changes(context) + self.assertEqual(mock_storage.write.call_count, 1) # Explicit save should bump + value_a = await property_a.get(context) + self.assertEqual("there", value_a) + self.assertEqual(mock_storage.write.call_count, 1) # Gets should not bump + await user_state.save_changes(context) + self.assertEqual(mock_storage.write.call_count, 1) + await property_a.delete(context) # Delete alone no bump + self.assertEqual(mock_storage.write.call_count, 1) + await user_state.save_changes(context) # Save when dirty should bump + self.assertEqual(mock_storage.write.call_count, 2) + self.assertEqual(mock_storage.read.call_count, 1) + await user_state.save_changes(context) # Save not dirty should not bump + self.assertEqual(mock_storage.write.call_count, 2) + self.assertEqual(mock_storage.read.call_count, 1) + + async def test_state_set_no_load(self): + """Should be able to set a property with no Load""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property_a") + await property_a.set(context, "hello") + + async def test_state_multiple_loads(self): + """Should be able to load multiple times""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + user_state.create_property("property_a") + await user_state.load(context) + await user_state.load(context) + + async def test_state_get_no_load_with_default(self): + """Should be able to get a property with no Load and default""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property_a") + value_a = await property_a.get(context, lambda: "Default!") + self.assertEqual("Default!", value_a) + + async def test_state_get_no_load_no_default(self): + """Cannot get a string with no default set""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property_a") + value_a = await property_a.get(context) + + # Assert + self.assertIsNone(value_a) + + async def test_state_poco_no_default(self): + """Cannot get a POCO with no default set""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + test_property = user_state.create_property("test") + value = await test_property.get(context) + + # Assert + self.assertIsNone(value) + + async def test_state_bool_no_default(self): + """Cannot get a bool with no default set""" + # Arange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + test_property = user_state.create_property("test") + value = await test_property.get(context) + + # Assert + self.assertFalse(value) + + async def test_state_set_after_save(self): + """Verify setting property after save""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property-a") + property_b = user_state.create_property("property-b") + + await user_state.load(context) + await property_a.set(context, "hello") + await property_b.set(context, "world") + await user_state.save_changes(context) + + await property_a.set(context, "hello2") + + async def test_state_multiple_save(self): + """Verify multiple saves""" + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property-a") + property_b = user_state.create_property("property-b") + + await user_state.load(context) + await property_a.set(context, "hello") + await property_b.set(context, "world") + await user_state.save_changes(context) + + await property_a.set(context, "hello2") + await user_state.save_changes(context) + value_a = await property_a.get(context) + self.assertEqual("hello2", value_a) + + async def test_load_set_save(self): + # Arrange + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + + # Act + property_a = user_state.create_property("property-a") + property_b = user_state.create_property("property-b") + + await user_state.load(context) + await property_a.set(context, "hello") + await property_b.set(context, "world") + await user_state.save_changes(context) + + # Assert + obj = dictionary["EmptyContext/users/empty@empty.context.org"] + self.assertEqual("hello", obj["property-a"]) + self.assertEqual("world", obj["property-b"]) + + async def test_load_set_save_twice(self): + # Arrange + dictionary = {} + context = TestUtilities.create_empty_context() + + # Act + user_state = UserState(MemoryStorage(dictionary)) + + property_a = user_state.create_property("property-a") + property_b = user_state.create_property("property-b") + property_c = user_state.create_property("property-c") + + await user_state.load(context) + await property_a.set(context, "hello") + await property_b.set(context, "world") + await property_c.set(context, "test") + await user_state.save_changes(context) + + # Assert + obj = dictionary["EmptyContext/users/empty@empty.context.org"] + self.assertEqual("hello", obj["property-a"]) + self.assertEqual("world", obj["property-b"]) + + # Act 2 + user_state2 = UserState(MemoryStorage(dictionary)) + + property_a2 = user_state2.create_property("property-a") + property_b2 = user_state2.create_property("property-b") + + await user_state2.load(context) + await property_a2.set(context, "hello-2") + await property_b2.set(context, "world-2") + await user_state2.save_changes(context) + + # Assert 2 + obj2 = dictionary["EmptyContext/users/empty@empty.context.org"] + self.assertEqual("hello-2", obj2["property-a"]) + self.assertEqual("world-2", obj2["property-b"]) + self.assertEqual("test", obj2["property-c"]) + + async def test_load_save_delete(self): + # Arrange + dictionary = {} + context = TestUtilities.create_empty_context() + + # Act + user_state = UserState(MemoryStorage(dictionary)) + + property_a = user_state.create_property("property-a") + property_b = user_state.create_property("property-b") + + await user_state.load(context) + await property_a.set(context, "hello") + await property_b.set(context, "world") + await user_state.save_changes(context) + + # Assert + obj = dictionary["EmptyContext/users/empty@empty.context.org"] + self.assertEqual("hello", obj["property-a"]) + self.assertEqual("world", obj["property-b"]) + + # Act 2 + user_state2 = UserState(MemoryStorage(dictionary)) + + property_a2 = user_state2.create_property("property-a") + property_b2 = user_state2.create_property("property-b") + + await user_state2.load(context) + await property_a2.set(context, "hello-2") + await property_b2.delete(context) + await user_state2.save_changes(context) + + # Assert 2 + obj2 = dictionary["EmptyContext/users/empty@empty.context.org"] + self.assertEqual("hello-2", obj2["property-a"]) + with self.assertRaises(KeyError) as _: + obj2["property-b"] # pylint: disable=pointless-statement + + async def test_state_use_bot_state_directly(self): + async def exec_test(context: TurnContext): + # pylint: disable=unnecessary-lambda + bot_state_manager = BotStateForTest(MemoryStorage()) + test_property = bot_state_manager.create_property("test") + + # read initial state object + await bot_state_manager.load(context) + + custom_state = await test_property.get(context, lambda: CustomState()) + + # this should be a 'CustomState' as nothing is currently stored in storage + assert isinstance(custom_state, CustomState) + + # amend property and write to storage + custom_state.custom_string = "test" + await bot_state_manager.save_changes(context) + + custom_state.custom_string = "asdfsadf" + + # read into context again + await bot_state_manager.load(context, True) + + custom_state = await test_property.get(context) + + # check object read from value has the correct value for custom_string + assert custom_state.custom_string == "test" + + adapter = TestAdapter(exec_test) + await adapter.send("start") + + async def test_user_state_bad_from_throws(self): + dictionary = {} + user_state = UserState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + context.activity.from_property = None + test_property = user_state.create_property("test") + with self.assertRaises(AttributeError): + await test_property.get(context) + + async def test_conversation_state_bad_conversation_throws(self): + dictionary = {} + user_state = ConversationState(MemoryStorage(dictionary)) + context = TestUtilities.create_empty_context() + context.activity.conversation = None + test_property = user_state.create_property("test") + with self.assertRaises(AttributeError): + await test_property.get(context) + + async def test_clear_and_save(self): + # pylint: disable=unnecessary-lambda + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + # Turn 0 + bot_state1 = ConversationState(storage) + ( + await bot_state1.create_property("test-name").get( + turn_context, lambda: TestPocoState() + ) + ).value = "test-value" + await bot_state1.save_changes(turn_context) + + # Turn 1 + bot_state2 = ConversationState(storage) + value1 = ( + await bot_state2.create_property("test-name").get( + turn_context, lambda: TestPocoState(value="default-value") + ) + ).value + + assert value1 == "test-value" + + # Turn 2 + bot_state3 = ConversationState(storage) + await bot_state3.clear_state(turn_context) + await bot_state3.save_changes(turn_context) + + # Turn 3 + bot_state4 = ConversationState(storage) + value2 = ( + await bot_state4.create_property("test-name").get( + turn_context, lambda: TestPocoState(value="default-value") + ) + ).value + + assert value2, "default-value" + + async def test_bot_state_delete(self): + # pylint: disable=unnecessary-lambda + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + # Turn 0 + bot_state1 = ConversationState(storage) + ( + await bot_state1.create_property("test-name").get( + turn_context, lambda: TestPocoState() + ) + ).value = "test-value" + await bot_state1.save_changes(turn_context) + + # Turn 1 + bot_state2 = ConversationState(storage) + value1 = ( + await bot_state2.create_property("test-name").get( + turn_context, lambda: TestPocoState(value="default-value") + ) + ).value + + assert value1 == "test-value" + + # Turn 2 + bot_state3 = ConversationState(storage) + await bot_state3.delete(turn_context) + + # Turn 3 + bot_state4 = ConversationState(storage) + value2 = ( + await bot_state4.create_property("test-name").get( + turn_context, lambda: TestPocoState(value="default-value") + ) + ).value + + assert value2 == "default-value" + + async def test_bot_state_get(self): + # pylint: disable=unnecessary-lambda + turn_context = TestUtilities.create_empty_context() + turn_context.activity.conversation = ConversationAccount(id="1234") + + storage = MemoryStorage({}) + + conversation_state = ConversationState(storage) + ( + await conversation_state.create_property("test-name").get( + turn_context, lambda: TestPocoState() + ) + ).value = "test-value" + + result = conversation_state.get(turn_context) + + assert result["test-name"].value == "test-value" diff --git a/libraries/botbuilder-dialogs/botbuilder/dialogs/waterfall_dialog.py b/libraries/botbuilder-dialogs/botbuilder/dialogs/waterfall_dialog.py index 678f341f8..bced214fb 100644 --- a/libraries/botbuilder-dialogs/botbuilder/dialogs/waterfall_dialog.py +++ b/libraries/botbuilder-dialogs/botbuilder/dialogs/waterfall_dialog.py @@ -59,7 +59,7 @@ async def begin_dialog( properties = {} properties["DialogId"] = self.id properties["InstanceId"] = instance_id - self.telemetry_client.track_event("WaterfallStart", properties=properties) + self.telemetry_client.track_event("WaterfallStart", properties) # Run first stepkinds return await self.run_step(dialog_context, 0, DialogReason.BeginCalled, None)