diff options
| -rw-r--r-- | pyproject.toml | 3 | ||||
| -rw-r--r-- | src/wikiget/client.py | 10 | ||||
| -rw-r--r-- | src/wikiget/dl.py | 39 | ||||
| -rw-r--r-- | src/wikiget/file.py | 8 | ||||
| -rw-r--r-- | tests/conftest.py | 10 | ||||
| -rw-r--r-- | tests/test_client.py | 2 | ||||
| -rw-r--r-- | tests/test_dl.py | 155 | ||||
| -rw-r--r-- | tests/test_file_class.py | 2 | ||||
| -rw-r--r-- | tests/test_logging.py | 8 | ||||
| -rw-r--r-- | tests/test_wikiget_cli.py | 2 |
10 files changed, 168 insertions, 71 deletions
diff --git a/pyproject.toml b/pyproject.toml index 5681dc4..46f1ddf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,10 +79,12 @@ cov-report = [ "coverage report", ] htmlcov = "coverage html" +xmlcov = "coverage xml" cov = [ "test-cov", "cov-report", "htmlcov", + "xmlcov", ] [[tool.hatch.envs.all.matrix]] @@ -198,6 +200,7 @@ warn_unused_ignores = true strict_equality = true check_untyped_defs = true extra_checks = true +disallow_untyped_defs = true [[tool.mypy.overrides]] module = [ diff --git a/src/wikiget/client.py b/src/wikiget/client.py index 3a6e40f..2729144 100644 --- a/src/wikiget/client.py +++ b/src/wikiget/client.py @@ -17,15 +17,21 @@ """Handle API calls (via mwclient) for site and image information.""" +from __future__ import annotations + import logging -from argparse import Namespace +from typing import TYPE_CHECKING from mwclient import APIError, InvalidResponse, LoginError, Site -from mwclient.image import Image from requests import ConnectionError, HTTPError import wikiget +if TYPE_CHECKING: + from argparse import Namespace + + from mwclient.image import Image + logger = logging.getLogger(__name__) diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py index 5e2a6cb..85c6685 100644 --- a/src/wikiget/dl.py +++ b/src/wikiget/dl.py @@ -17,33 +17,35 @@ """Prepare and process file downloads.""" +from __future__ import annotations + import logging import sys -from argparse import Namespace from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING -from mwclient import APIError, InvalidResponse, LoginError +from mwclient import APIError, InvalidResponse, LoginError, Site from requests import ConnectionError, HTTPError from tqdm import tqdm import wikiget from wikiget.client import connect_to_site, query_api from wikiget.exceptions import ParseError -from wikiget.file import File from wikiget.logging import FileLogAdapter from wikiget.parse import get_dest, read_batch_file from wikiget.validations import verify_hash +if TYPE_CHECKING: + from argparse import Namespace + + from wikiget.file import File + logger = logging.getLogger(__name__) def prep_download(dl: str, args: Namespace) -> File: """Prepare to download a file by parsing the filename or URL and CLI arguments. - First, the target is parsed for a valid name, destination, and site. If there are no - problems creating a File with this information, we connect to the site hosting it - and fetch the relevant Image object, which is added as an attribute to the File. - :param dl: a string representing the file or URL to download :type dl: str :param args: command-line arguments and their values @@ -59,8 +61,6 @@ def prep_download(dl: str, args: Namespace) -> File: msg = f"[{file.dest}] File already exists; skipping download (use -f to force)" raise FileExistsError(msg) - site = connect_to_site(file.site, args) - file.image = query_api(file.name, site) return file @@ -98,6 +98,8 @@ def process_download(args: Namespace) -> int: # single download mode try: file = prep_download(args.FILE, args) + site = connect_to_site(file.site, args) + file.image = query_api(file.name, site) except ParseError as e: logger.error(e) exit_code = 1 @@ -134,14 +136,31 @@ def batch_download(args: Namespace) -> int: logger.error("File could not be read: %s", str(e)) sys.exit(1) - # TODO: validate file contents before download process starts with ThreadPoolExecutor(max_workers=args.threads) as executor: futures = [] + sites: list[Site] = [] for line_num, line in dl_dict.items(): # keep track of batch file line numbers for debugging/logging purposes logger.info("Processing '%s' at line %i", line, line_num) try: file = prep_download(line, args) + site = next( + filter( + lambda site: site.host == file.site, + sites, + ), + None, + ) + # if there's already a Site object matching the desired host, reuse it + # to reduce the number of API calls made per file + if site: + logger.debug("Reusing the existing connection to %s", site.host) + else: + logger.debug("Making a new connection to %s", file.site) + site = connect_to_site(file.site, args) + # cache the new Site for reuse + sites.append(site) + file.image = query_api(file.name, site) except ParseError as e: logger.warning("%s (line %i)", str(e), line_num) errors += 1 diff --git a/src/wikiget/file.py b/src/wikiget/file.py index 36ce892..0b0c1e0 100644 --- a/src/wikiget/file.py +++ b/src/wikiget/file.py @@ -17,12 +17,16 @@ """Define a File class for representing individual files to be downloaded.""" -from pathlib import Path +from __future__ import annotations -from mwclient.image import Image +from pathlib import Path +from typing import TYPE_CHECKING from wikiget import DEFAULT_SITE +if TYPE_CHECKING: + from mwclient.image import Image + class File: """A file object.""" diff --git a/tests/conftest.py b/tests/conftest.py index 6088029..128b581 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,13 +17,19 @@ """Define fixtures used across all tests in this folder.""" -from pathlib import Path +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest -import requests_mock as rm from wikiget.file import File +if TYPE_CHECKING: + from pathlib import Path + + import requests_mock as rm + # 2x2 JPEG TEST_FILE_BYTES = ( b"\xff\xd8\xff\xdb\x00C\x00\x03\x02\x02\x02\x02\x02\x03\x02\x02\x02\x03\x03\x03\x03" diff --git a/tests/test_client.py b/tests/test_client.py index dae63f5..a0e4855 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,6 +17,8 @@ """Define tests related to the wikiget.client module.""" +from __future__ import annotations + import logging from unittest.mock import MagicMock, patch, sentinel diff --git a/tests/test_dl.py b/tests/test_dl.py index e062116..7117e50 100644 --- a/tests/test_dl.py +++ b/tests/test_dl.py @@ -17,8 +17,10 @@ """Define tests related to the wikiget.dl module.""" +from __future__ import annotations + import logging -from pathlib import Path +from typing import TYPE_CHECKING from unittest.mock import MagicMock, Mock, patch import pytest @@ -30,24 +32,16 @@ from wikiget.exceptions import ParseError from wikiget.file import File from wikiget.wikiget import parse_args +if TYPE_CHECKING: + from pathlib import Path + class TestPrepDownload: """Define tests related to wikiget.dl.prep_download.""" - @patch("wikiget.dl.query_api") - @patch("wikiget.dl.connect_to_site") - def test_prep_download( - self, mock_connect_to_site: MagicMock, mock_query_api: MagicMock - ) -> None: + def test_prep_download(self) -> None: """The prep_download function should create the expected file object.""" - mock_site = Mock() - mock_image = Mock() - - mock_connect_to_site.return_value = mock_site - mock_query_api.return_value = mock_image - expected_file = File(name="Example.jpg") - expected_file.image = mock_image args = parse_args(["File:Example.jpg"]) file = prep_download(args.FILE, args) @@ -104,7 +98,8 @@ class TestProcessDownload: mock_prep_download.return_value = File("Example.jpg") args = parse_args(["File:Example.jpg"]) - exit_code = process_download(args) + with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"): + exit_code = process_download(args) assert exit_code == 0 @@ -118,7 +113,8 @@ class TestProcessDownload: mock_prep_download.return_value = File("Example.jpg") args = parse_args(["File:Example.jpg"]) - exit_code = process_download(args) + with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"): + exit_code = process_download(args) assert exit_code == 1 @@ -164,17 +160,15 @@ class TestProcessDownload: assert exit_code == 1 +@patch("wikiget.dl.read_batch_file") class TestBatchDownload: """Define tests related to wikiget.dl.batch_download.""" @patch("wikiget.dl.download") - @patch("wikiget.dl.prep_download") - @patch("wikiget.dl.read_batch_file") def test_batch_download( self, - mock_read_batch_file: MagicMock, - mock_prep_download: MagicMock, mock_download: MagicMock, + mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture, ) -> None: """Test that no errors are returned for a successful batch download. @@ -189,17 +183,57 @@ class TestBatchDownload: mock_download.return_value = 0 args = parse_args(["-a", "batch.txt"]) - errors = batch_download(args) + with patch("wikiget.dl.query_api"), patch("wikiget.dl.connect_to_site"), patch( + "wikiget.dl.prep_download" + ): + errors = batch_download(args) assert mock_read_batch_file.called - assert mock_prep_download.called assert mock_download.called assert caplog.record_tuples == [ ("wikiget.dl", logging.INFO, "Processing 'File:Example.jpg' at line 1") ] assert errors == 0 - @patch("wikiget.dl.read_batch_file") + @patch("wikiget.dl.connect_to_site") + @patch("wikiget.dl.prep_download") + def test_batch_download_reuse_site( + self, + mock_prep_download: MagicMock, + mock_connect_to_site: MagicMock, + mock_read_batch_file: MagicMock, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Test that an existing site object is reused.""" + caplog.set_level(logging.DEBUG) + + mock_site = MagicMock() + mock_site.host = "commons.wikimedia.org" + mock_read_batch_file.return_value = { + 1: "File:Example.jpg", + 2: "File:Foobar.jpg", + } + mock_prep_download.return_value = File("Example.jpg") + mock_connect_to_site.return_value = mock_site + + args = parse_args(["-a", "batch.txt"]) + with patch("wikiget.dl.download"), patch("wikiget.dl.query_api"): + _ = batch_download(args) + + assert mock_read_batch_file.called + assert mock_prep_download.called + assert mock_connect_to_site.called + assert caplog.record_tuples[1] == ( + "wikiget.dl", + logging.DEBUG, + "Making a new connection to commons.wikimedia.org", + ) + assert caplog.record_tuples[3] == ( + "wikiget.dl", + logging.DEBUG, + "Reusing the existing connection to commons.wikimedia.org", + ) + def test_batch_download_os_error( self, mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture ) -> None: @@ -216,11 +250,10 @@ class TestBatchDownload: ] @patch("wikiget.dl.prep_download") - @patch("wikiget.dl.read_batch_file") def test_batch_download_parse_error( self, - mock_read_batch_file: MagicMock, mock_prep_download: MagicMock, + mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning log message is created if ParseError is raised. @@ -242,11 +275,10 @@ class TestBatchDownload: assert errors == 1 @patch("wikiget.dl.prep_download") - @patch("wikiget.dl.read_batch_file") def test_batch_download_file_exists_error( self, - mock_read_batch_file: MagicMock, mock_prep_download: MagicMock, + mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning log message is created if the download file exists.""" @@ -264,11 +296,10 @@ class TestBatchDownload: assert errors == 1 @patch("wikiget.dl.prep_download") - @patch("wikiget.dl.read_batch_file") def test_batch_download_other_error( self, - mock_read_batch_file: MagicMock, mock_prep_download: MagicMock, + mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning log message is created if there are problems downloading. @@ -317,7 +348,13 @@ class TestDownload: file.image.site.connection = requests.Session() return file - def test_download(self, mock_file: File, caplog: pytest.LogCaptureFixture) -> None: + @patch("wikiget.dl.verify_hash") + def test_download( + self, + mock_verify_hash: MagicMock, + mock_file: File, + caplog: pytest.LogCaptureFixture, + ) -> None: """Test that the correct log messages are created when downloading a file. There should be a series of info-level messages containing the filename, size, @@ -326,10 +363,10 @@ class TestDownload: """ caplog.set_level(logging.INFO) - with patch("wikiget.dl.verify_hash") as mock_verify_hash: - mock_verify_hash.return_value = "d01b79a6781c72ac9bfff93e5e2cfbeef4efc840" - args = parse_args(["File:Example.jpg"]) - errors = download(mock_file, args) + mock_verify_hash.return_value = "d01b79a6781c72ac9bfff93e5e2cfbeef4efc840" + + args = parse_args(["File:Example.jpg"]) + errors = download(mock_file, args) assert caplog.record_tuples == [ ( @@ -361,8 +398,12 @@ class TestDownload: ] assert errors == 0 + @patch("wikiget.dl.verify_hash") def test_download_with_output( - self, mock_file: File, caplog: pytest.LogCaptureFixture + self, + mock_verify_hash: MagicMock, + mock_file: File, + caplog: pytest.LogCaptureFixture, ) -> None: """Test that the correct log messages are created when downloading a file. @@ -371,11 +412,10 @@ class TestDownload: caplog.set_level(logging.INFO) tmp_file = mock_file.dest + mock_verify_hash.return_value = "d01b79a6781c72ac9bfff93e5e2cfbeef4efc840" - with patch("wikiget.dl.verify_hash") as mock_verify_hash: - mock_verify_hash.return_value = "d01b79a6781c72ac9bfff93e5e2cfbeef4efc840" - args = parse_args(["-o", str(tmp_file), "File:Example.jpg"]) - errors = download(mock_file, args) + args = parse_args(["-o", str(tmp_file), "File:Example.jpg"]) + errors = download(mock_file, args) assert caplog.record_tuples[0] == ( "wikiget.dl", @@ -405,18 +445,19 @@ class TestDownload: ] assert errors == 0 + @patch("pathlib.Path.open") def test_download_os_error( - self, mock_file: File, caplog: pytest.LogCaptureFixture + self, mock_open: MagicMock, mock_file: File, caplog: pytest.LogCaptureFixture ) -> None: """Test what happens when an OSError is raised during download. If the downloaded file cannot be created, an error log message should be created with details on the exception. """ - with patch("pathlib.Path.open") as mock_open: - mock_open.side_effect = OSError("write error") - args = parse_args(["File:Example.jpg"]) - errors = download(mock_file, args) + mock_open.side_effect = OSError("write error") + + args = parse_args(["File:Example.jpg"]) + errors = download(mock_file, args) assert caplog.record_tuples == [ ( @@ -427,18 +468,22 @@ class TestDownload: ] assert errors == 1 + @patch("wikiget.dl.verify_hash") def test_download_verify_os_error( - self, mock_file: File, caplog: pytest.LogCaptureFixture + self, + mock_verify_hash: MagicMock, + mock_file: File, + caplog: pytest.LogCaptureFixture, ) -> None: """Test what happens when an OSError is raised during verification. If the downloaded file cannot be read in order to calculate its hash, an error log message should be created with details on the exception. """ - with patch("wikiget.dl.verify_hash") as mock_verify_hash: - mock_verify_hash.side_effect = OSError("read error") - args = parse_args(["File:Example.jpg"]) - errors = download(mock_file, args) + mock_verify_hash.side_effect = OSError("read error") + + args = parse_args(["File:Example.jpg"]) + errors = download(mock_file, args) assert caplog.record_tuples == [ ( @@ -449,17 +494,21 @@ class TestDownload: ] assert errors == 1 + @patch("wikiget.dl.verify_hash") def test_download_verify_hash_mismatch( - self, mock_file: File, caplog: pytest.LogCaptureFixture + self, + mock_verify_hash: MagicMock, + mock_file: File, + caplog: pytest.LogCaptureFixture, ) -> None: """Test what happens when the downloaded file hash and server hash don't match. An error log message should be created if there's a hash mismatch. """ - with patch("wikiget.dl.verify_hash") as mock_verify_hash: - mock_verify_hash.return_value = "mismatch" - args = parse_args(["File:Example.jpg"]) - errors = download(mock_file, args) + mock_verify_hash.return_value = "mismatch" + + args = parse_args(["File:Example.jpg"]) + errors = download(mock_file, args) assert caplog.record_tuples == [ ( diff --git a/tests/test_file_class.py b/tests/test_file_class.py index 4ad06d1..699f40d 100644 --- a/tests/test_file_class.py +++ b/tests/test_file_class.py @@ -17,6 +17,8 @@ """Define tests related to the wikiget.file module.""" +from __future__ import annotations + from wikiget import DEFAULT_SITE from wikiget.file import File diff --git a/tests/test_logging.py b/tests/test_logging.py index 8d58cdf..a402120 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -17,14 +17,18 @@ """Define tests related to the wikiget.logging module.""" +from __future__ import annotations + import logging from pathlib import Path - -import pytest +from typing import TYPE_CHECKING from wikiget.logging import FileLogAdapter, configure_logging from wikiget.wikiget import parse_args +if TYPE_CHECKING: + import pytest + class TestLogging: """Define tests related to wikiget.logging.configure_logging and FileLogAdapter.""" diff --git a/tests/test_wikiget_cli.py b/tests/test_wikiget_cli.py index 15e2a7b..898ab8e 100644 --- a/tests/test_wikiget_cli.py +++ b/tests/test_wikiget_cli.py @@ -17,6 +17,8 @@ """Define tests related to the wikiget.wikiget module.""" +from __future__ import annotations + import logging from unittest.mock import MagicMock, patch |
