Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions _test_unstructured_client/integration/test_decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import tempfile
from pathlib import Path

import httpx
import json
import pytest
Expand Down Expand Up @@ -102,6 +105,87 @@ def test_integration_split_pdf_has_same_output_as_non_split(
)
assert len(diff) == 0

@pytest.mark.parametrize(
    ("filename", "expected_ok", "strategy"),
    [
        ("_sample_docs/layout-parser-paper.pdf", True, "hi_res"),  # 16 pages
    ],
)
@pytest.mark.parametrize(
    ("use_caching", "cache_dir"),
    [
        (True, None),  # Use default cache dir
        (True, Path(tempfile.gettempdir()) / "test_integration_unstructured_client1"),  # Use custom cache dir
        (False, None),  # Don't use caching
        (False, Path(tempfile.gettempdir()) / "test_integration_unstructured_client2"),  # Don't use caching, use custom cache dir
    ],
)
def test_integration_split_pdf_with_caching(
    filename: str,
    expected_ok: bool,
    strategy: str,
    use_caching: bool,
    cache_dir: Path | None,
):
    """Check that a split-PDF request matches a non-split request under every caching setup.

    The same document is partitioned twice against a server on localhost:8000:
    first with ``split_pdf_page=True`` (combined with the parametrized
    ``split_pdf_cache_tmp_data`` / ``split_pdf_cache_dir`` values), then with
    ``split_pdf_page=False``. Both responses must agree on element count,
    content type, status code and element contents (ignoring element ids and
    parent ids). When an explicit ``cache_dir`` is given, it must not exist
    once the requests have completed.
    """
    try:
        response = requests.get("http://localhost:8000/general/docs")
    except requests.exceptions.ConnectionError:
        # pytest.fail states the intent directly instead of `assert False`.
        pytest.fail("The unstructured-api is not running on localhost:8000")
    assert response.status_code == 200, "The unstructured-api is not running on localhost:8000"

    client = UnstructuredClient(api_key_auth=FAKE_KEY, server_url="localhost:8000")

    with open(filename, "rb") as f:
        files = shared.Files(
            content=f.read(),
            file_name=filename,
        )

    if not expected_ok:
        # This will append .pdf to filename to fool first line of filetype detection, to simulate decoding error
        files.file_name += ".pdf"

    parameters = shared.PartitionParameters(
        files=files,
        strategy=strategy,
        languages=["eng"],
        split_pdf_page=True,
        split_pdf_cache_tmp_data=use_caching,
        split_pdf_cache_dir=cache_dir,
    )

    req = operations.PartitionRequest(
        partition_parameters=parameters
    )

    try:
        resp_split = client.general.partition(request=req)
    except (HTTPValidationError, AttributeError) as exc:
        if not expected_ok:
            assert "File does not appear to be a valid PDF" in str(exc)
            return
        # The request was expected to succeed: re-raise so the report shows the
        # real exception and traceback rather than an opaque assertion failure.
        raise

    # Repeat the very same request without splitting to obtain the baseline.
    parameters.split_pdf_page = False

    req = operations.PartitionRequest(
        partition_parameters=parameters
    )

    resp_single = client.general.partition(request=req)

    assert len(resp_split.elements) == len(resp_single.elements)
    assert resp_split.content_type == resp_single.content_type
    assert resp_split.status_code == resp_single.status_code

    diff = DeepDiff(
        t1=resp_split.elements,
        t2=resp_single.elements,
        # These two fields are excluded from the comparison between the runs.
        exclude_regex_paths=[
            r"root\[\d+\]\['metadata'\]\['parent_id'\]",
            r"root\[\d+\]\['element_id'\]",
        ],
    )
    assert len(diff) == 0

    # make sure the cache dir was cleaned if passed explicitly
    if cache_dir is not None:
        assert not Path(cache_dir).exists()



def test_integration_split_pdf_for_file_with_no_name():
"""
Expand Down
72 changes: 72 additions & 0 deletions _test_unstructured_client/unit/test_request_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Unit tests for the request_utils.py module
import httpx
import pytest

from unstructured_client._hooks.custom.request_utils import create_pdf_chunk_request_params, get_multipart_stream_fields
from unstructured_client.models import shared


# Parametrized cases: an empty multipart request, a non-multipart request, and a multipart upload with a file
@pytest.mark.parametrize(
    ("input_request", "expected"),
    [
        (
            httpx.Request(
                "POST",
                "http://localhost:8000",
                data={},
                headers={"Content-Type": "multipart/form-data"},
            ),
            {},
        ),
        (
            httpx.Request(
                "POST",
                "http://localhost:8000",
                data={"hello": "world"},
                headers={"Content-Type": "application/json"},
            ),
            {},
        ),
        (
            httpx.Request(
                "POST",
                "http://localhost:8000",
                data={"hello": "world"},
                files={"files": ("hello.pdf", b"hello", "application/pdf")},
                headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"},
            ),
            {
                "hello": "world",
                "files": {
                    "content_type": "application/pdf",
                    "filename": "hello.pdf",
                    "file": b"hello",
                },
            },
        ),
    ],
)
def test_get_multipart_stream_fields(input_request, expected):
    """Check the fields extracted from a request by get_multipart_stream_fields.

    The empty multipart request and the JSON request are both expected to
    yield an empty dict; the multipart upload is expected to yield its form
    field plus a ``files`` entry holding content type, filename and content.
    """
    assert get_multipart_stream_fields(input_request) == expected

def test_multipart_stream_fields_raises_value_error_when_filename_is_not_set():
    """A multipart upload whose file part has an empty filename must raise ValueError."""
    with pytest.raises(ValueError):
        nameless_upload = httpx.Request(
            "POST",
            "http://localhost:8000",
            data={"hello": "world"},
            files={"files": ("", b"hello", "application/pdf")},
            headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"},
        )
        get_multipart_stream_fields(nameless_upload)

@pytest.mark.parametrize(
    ("input_form_data", "page_number", "expected_form_data"),
    [
        (
            {"hello": "world"},
            2,
            {"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"},
        ),
        (
            {"hello": "world", "split_pdf_page": "true"},
            2,
            {"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"},
        ),
        (
            {"hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
            3,
            {"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"},
        ),
        (
            {"split_pdf_page_range[]": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
            3,
            {"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"},
        ),
        (
            {"split_pdf_page_range": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
            4,
            {"hello": "world", "split_pdf_page": "false", "starting_page_number": "4"},
        ),
    ],
)
def test_create_pdf_chunk_request_params(input_form_data, page_number, expected_form_data):
    """Check the per-chunk form data built by create_pdf_chunk_request_params.

    In every case the expected result has ``split_pdf_page`` set to "false",
    ``starting_page_number`` set to the stringified page number, and no
    ``files`` or page-range keys; other fields are carried over unchanged.
    """
    chunk_params = create_pdf_chunk_request_params(input_form_data, page_number)
    assert chunk_params == expected_form_data
125 changes: 40 additions & 85 deletions _test_unstructured_client/unit/test_split_pdf_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from asyncio import Task
from collections import Counter
from functools import partial
from typing import Coroutine

import httpx
Expand Down Expand Up @@ -53,29 +54,6 @@ async def example():
assert hook.api_successful_responses.get(operation_id) is None


def test_unit_prepare_request_payload():
    """Test prepare request payload method properly sets split_pdf_page to 'false'
    and removes files key."""
    test_form_data = {
        "files": ("test_file.pdf", b"test_file_content"),
        "split_pdf_page": "true",
        "parameter_1": "value_1",
        "parameter_2": "value_2",
        "parameter_3": "value_3",
    }
    expected_form_data = {
        "split_pdf_page": "false",
        "parameter_1": "value_1",
        "parameter_2": "value_2",
        "parameter_3": "value_3",
    }

    payload = request_utils.prepare_request_payload(test_form_data)

    assert payload != test_form_data
    # Compare with `==`: `assert payload, expected_form_data` would only check
    # that payload is truthy and use the expected dict as the failure message.
    assert payload == expected_form_data

def test_unit_prepare_request_headers():
"""Test prepare request headers method properly removes Content-Type and Content-Length headers."""
test_headers = {
Expand Down Expand Up @@ -224,61 +202,31 @@ def test_unit_parse_form_data_none_filename_error():
form_utils.parse_form_data(decoded_data)


def test_unit_is_pdf_valid_pdf():
"""Test is pdf method returns True for valid pdf file with filename."""
def test_unit_is_pdf_valid_pdf_when_passing_file_object():
"""Test is pdf method returns pdf object for valid pdf file with filename."""
filename = "_sample_docs/layout-parser-paper-fast.pdf"

with open(filename, "rb") as f:
file = shared.Files(
content=f.read(),
file_name=filename,
)

result = pdf_utils.is_pdf(file)
result = pdf_utils.read_pdf(f)

assert result is True
assert result is not None


def test_unit_is_pdf_valid_pdf_without_file_extension():
"""Test is pdf method returns True for file with valid pdf content without basing on file extension."""
def test_unit_is_pdf_valid_pdf_when_passing_binary_content():
"""Test is pdf method returns pdf object for file with valid pdf content"""
filename = "_sample_docs/layout-parser-paper-fast.pdf"

with open(filename, "rb") as f:
file = shared.Files(
content=f.read(),
file_name="uuid1234",
)

result = pdf_utils.is_pdf(file)

assert result is True


def test_unit_is_pdf_invalid_extension():
"""Test is pdf method returns False for file with invalid extension."""
file = shared.Files(content=b"txt_content", file_name="test_file.txt")

result = pdf_utils.is_pdf(file)
result = pdf_utils.read_pdf(f.read())

assert result is False
assert result is not None


def test_unit_is_pdf_invalid_pdf():
"""Test is pdf method returns False for file with invalid pdf content."""
file = shared.Files(content=b"invalid_pdf_content", file_name="test_file.pdf")

result = pdf_utils.is_pdf(file)

assert result is False


def test_unit_is_pdf_invalid_pdf_without_file_extension():
"""Test is pdf method returns False for file with invalid pdf content without basing on file extension."""
file = shared.Files(content=b"invalid_pdf_content", file_name="uuid1234")

result = pdf_utils.is_pdf(file)
"""Test is pdf method returns False for file with invalid extension."""
result = pdf_utils.read_pdf(b"txt_content")

assert result is False
assert result is None


def test_unit_get_starting_page_number_missing_key():
Expand Down Expand Up @@ -388,7 +336,10 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
assert result == expected_result


async def _request_mock(fails: bool, content: str) -> requests.Response:
async def _request_mock(
async_client: httpx.AsyncClient, # not used by mock
fails: bool,
content: str) -> requests.Response:
response = requests.Response()
response.status_code = 500 if fails else 200
response._content = content.encode()
Expand All @@ -399,40 +350,40 @@ async def _request_mock(fails: bool, content: str) -> requests.Response:
("allow_failed", "tasks", "expected_responses"), [
pytest.param(
True, [
_request_mock(fails=False, content="1"),
_request_mock(fails=False, content="2"),
_request_mock(fails=False, content="3"),
_request_mock(fails=False, content="4"),
partial(_request_mock, fails=False, content="1"),
partial(_request_mock, fails=False, content="2"),
partial(_request_mock, fails=False, content="3"),
partial(_request_mock, fails=False, content="4"),
],
["1", "2", "3", "4"],
id="no failures, fails allower"
),
pytest.param(
True, [
_request_mock(fails=False, content="1"),
_request_mock(fails=True, content="2"),
_request_mock(fails=False, content="3"),
_request_mock(fails=True, content="4"),
partial(_request_mock, fails=False, content="1"),
partial(_request_mock, fails=True, content="2"),
partial(_request_mock, fails=False, content="3"),
partial(_request_mock, fails=True, content="4"),
],
["1", "2", "3", "4"],
id="failures, fails allowed"
),
pytest.param(
False, [
_request_mock(fails=True, content="failure"),
_request_mock(fails=False, content="2"),
_request_mock(fails=True, content="failure"),
_request_mock(fails=False, content="4"),
partial(_request_mock, fails=True, content="failure"),
partial(_request_mock, fails=False, content="2"),
partial(_request_mock, fails=True, content="failure"),
partial(_request_mock, fails=False, content="4"),
],
["failure"],
id="failures, fails disallowed"
),
pytest.param(
False, [
_request_mock(fails=False, content="1"),
_request_mock(fails=False, content="2"),
_request_mock(fails=False, content="3"),
_request_mock(fails=False, content="4"),
partial(_request_mock, fails=False, content="1"),
partial(_request_mock, fails=False, content="2"),
partial(_request_mock, fails=False, content="3"),
partial(_request_mock, fails=False, content="4"),
],
["1", "2", "3", "4"],
id="no failures, fails disallowed"
Expand All @@ -451,14 +402,18 @@ async def test_unit_disallow_failed_coroutines(
assert response_contents == expected_responses


async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
async def _fetch_canceller_error(
async_client: httpx.AsyncClient, # not used by mock
fails: bool,
content: str,
cancelled_counter: Counter):
try:
if not fails:
await asyncio.sleep(0.01)
print("Doesn't fail")
else:
print("Fails")
return await _request_mock(fails=fails, content=content)
return await _request_mock(async_client=async_client, fails=fails, content=content)
except asyncio.CancelledError:
cancelled_counter.update(["cancelled"])
print(cancelled_counter["cancelled"])
Expand All @@ -469,8 +424,8 @@ async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: C
async def test_remaining_tasks_cancelled_when_fails_disallowed():
cancelled_counter = Counter()
tasks = [
_fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter),
*[_fetch_canceller_error(fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
partial(_fetch_canceller_error, fails=True, content="1", cancelled_counter=cancelled_counter),
*[partial(_fetch_canceller_error, fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
for i in range(2, 200)],
]

Expand Down
Loading