diff --git a/tests/test_helper.py b/tests/test_helper.py index 66632a1..b1b1a1c 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -1,14 +1,26 @@ import binascii import os +import re +from typing import Dict +from urllib.parse import urlparse import pytest import requests +from pytest_socket import disable_socket from requests.models import Response from aqt import helper +from aqt.exceptions import ArchiveConnectionError, ArchiveDownloadError +from aqt.helper import getUrl from aqt.metadata import Version +@pytest.fixture(autouse=True) +def disable_sockets(): + # This blocks all network connections, causing test failure if we used monkeypatch wrong + disable_socket() + + def test_helper_altlink(monkeypatch): class Message: headers = {"content-type": "text/plain", "length": 300} @@ -91,6 +103,60 @@ def test_helper_downloadBinary_sha256(tmp_path, monkeypatch): helper.downloadBinaryFile("http://example.com/test.xml", out, "sha256", expected, 60) +@pytest.mark.parametrize( + "mock_exception, expected_err_msg", + ( + (requests.exceptions.ConnectionError("Connection failed!"), "Connection error: ('Connection failed!',)"), + (requests.exceptions.Timeout("Connection timed out!"), "Connection timeout: ('Connection timed out!',)"), + ), +) +def test_helper_downloadBinary_connection_err(tmp_path, monkeypatch, mock_exception, expected_err_msg): + def _mock_get_conn_error(*args, **kwargs): + raise mock_exception + + monkeypatch.setattr(requests.Session, "get", _mock_get_conn_error) + + expected = binascii.unhexlify("1d41a93e4a585bb01e4518d4af431933") + out = tmp_path.joinpath("text.xml") + with pytest.raises(ArchiveConnectionError) as e: + helper.downloadBinaryFile("http://example.com/test.xml", out, "md5", expected, 60) + assert e.type == ArchiveConnectionError + assert format(e.value) == expected_err_msg + + +def test_helper_downloadBinary_wrong_checksum(tmp_path, monkeypatch): + monkeypatch.setattr(requests.Session, "get", mocked_requests_get) + + actual_hash = binascii.unhexlify("1d41a93e4a585bb01e4518d4af431933") + wrong_hash = binascii.unhexlify("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + expected_err = f"Download file is corrupted! Detect checksum error.\nExpected {wrong_hash}, Actual {actual_hash}" + out = tmp_path.joinpath("text.xml") + with pytest.raises(ArchiveDownloadError) as e: + helper.downloadBinaryFile("http://example.com/test.xml", out, "md5", wrong_hash, 60) + assert e.type == ArchiveDownloadError + assert format(e.value) == expected_err + + +def test_helper_downloadBinary_response_error_undefined(tmp_path, monkeypatch): + def iter_broken_content(*args, **kwargs): + raise RuntimeError("This chunk of downloaded content contains an error.") + + def mock_requests_get(*args, **kwargs): + response = Response() + response.status_code = 200 + response.iter_content = iter_broken_content + return response + + monkeypatch.setattr(requests.Session, "get", mock_requests_get) + + expected = binascii.unhexlify("1d41a93e4a585bb01e4518d4af431933") + out = tmp_path.joinpath("text.xml") + with pytest.raises(ArchiveDownloadError) as e: + helper.downloadBinaryFile("http://example.com/test.xml", out, "md5", expected, 60) + assert e.type == ArchiveDownloadError + assert format(e.value) == "Download error: This chunk of downloaded content contains an error." + + @pytest.mark.parametrize( "version, expect", [ @@ -108,3 +174,90 @@ def test_helper_downloadBinary_sha256(tmp_path, monkeypatch): ) def test_helper_to_version_permissive(version, expect): assert Version.permissive(version) == expect + + +def mocked_request_response_class(num_redirects: int = 0, forbidden_baseurls=None): + if not forbidden_baseurls: + forbidden_hostnames = [] + else: + forbidden_hostnames = [urlparse(host).hostname for host in forbidden_baseurls] + + class MockResponse: + redirects_for_host = {} + + def __init__(self, url: str, headers: Dict, text: str): + self.url = url + self.headers = {key: value for key, value in headers.items()} + + hostname = urlparse(url).hostname + if hostname not in MockResponse.redirects_for_host: + MockResponse.redirects_for_host[hostname] = num_redirects + + if MockResponse.redirects_for_host[hostname] > 0: + MockResponse.redirects_for_host[hostname] -= 1 + self.status_code = 302 + self.headers["Location"] = f"{url}/redirect{MockResponse.redirects_for_host[hostname]}" + self.text = f"Still {MockResponse.redirects_for_host[hostname]} redirects to go..." + self.reason = "Redirect" + elif hostname in forbidden_hostnames: + raise requests.exceptions.ConnectionError() + else: + self.status_code = 200 + self.text = text + + return MockResponse + + +def test_helper_getUrl_ok(monkeypatch): + response_class = mocked_request_response_class() + + def _mock_get(url, **kwargs): + return response_class(url, {}, "some_html_content") + + monkeypatch.setattr(requests, "get", _mock_get) + assert getUrl("some_url", timeout=(5, 5)) == "some_html_content" + + +def mock_get_redirect(num_redirects: int): + response_class = mocked_request_response_class(num_redirects) + + def _mock(url: str, timeout, allow_redirects): + return response_class(url, {}, text="some_html_content") + + def _mock_session(self, url: str, timeout, stream): + return response_class(url, {}, text="some_html_content") + + return _mock, _mock_session + + +def test_helper_getUrl_redirect_5(monkeypatch): + mocked_get, mocked_session_get = mock_get_redirect(num_redirects=5) + monkeypatch.setattr(requests, "get", mocked_get) + monkeypatch.setattr(requests.Session, "get", mocked_session_get) + assert getUrl("some_url", (5, 5)) == "some_html_content" + + +def test_helper_getUrl_redirect_too_many(monkeypatch): + mocked_get, mocked_session_get = mock_get_redirect(num_redirects=11) + monkeypatch.setattr(requests, "get", mocked_get) + monkeypatch.setattr(requests.Session, "get", mocked_session_get) + with pytest.raises(ArchiveDownloadError) as e: + getUrl("some_url", (5, 5)) + assert e.type == ArchiveDownloadError + + +def test_helper_getUrl_conn_error(monkeypatch): + response_class = mocked_request_response_class(forbidden_baseurls=["https://www.forbidden.com"]) + url = "https://www.forbidden.com/some_path" + timeout = (5, 5) + + expect_re = re.compile(r"^Failure to connect to.+" + re.escape(url)) + + def _mock(url: str, *args, **kwargs): + return response_class(url, {}, text="some_html_content") + + monkeypatch.setattr(requests, "get", _mock) + with pytest.raises(ArchiveConnectionError) as e: + getUrl(url, timeout) + assert e.type == ArchiveConnectionError + assert expect_re.match(format(e.value))