diff options
| -rw-r--r-- | src/wikiget/dl.py | 19 | ||||
| -rw-r--r-- | tests/test_dl.py | 67 |
2 files changed, 61 insertions, 25 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py index 60405d0..c9b2ed5 100644 --- a/src/wikiget/dl.py +++ b/src/wikiget/dl.py @@ -20,7 +20,7 @@ import sys from argparse import Namespace from concurrent.futures import ThreadPoolExecutor -from mwclient import APIError, InvalidResponse, LoginError +from mwclient import APIError, InvalidResponse, LoginError, Site from requests import ConnectionError, HTTPError from tqdm import tqdm @@ -38,10 +38,6 @@ logger = logging.getLogger(__name__) def prep_download(dl: str, args: Namespace) -> File: """Prepare to download a file by parsing the filename or URL and CLI arguments. - First, the target is parsed for a valid name, destination, and site. If there are no - problems creating a File with this information, we connect to the site hosting it - and fetch the relevant Image object, which is added as an attribute to the File. - :param dl: a string representing the file or URL to download :type dl: str :param args: command-line arguments and their values @@ -57,8 +53,6 @@ def prep_download(dl: str, args: Namespace) -> File: msg = f"[{file.dest}] File already exists; skipping download (use -f to force)" raise FileExistsError(msg) - site = connect_to_site(file.site, args) - file.image = query_api(file.name, site) return file @@ -96,6 +90,8 @@ def process_download(args: Namespace) -> int: # single download mode try: file = prep_download(args.FILE, args) + site = connect_to_site(file.site, args) + file.image = query_api(file.name, site) except ParseError as e: logger.error(e) exit_code = 1 @@ -135,11 +131,20 @@ def batch_download(args: Namespace) -> int: # TODO: validate file contents before download process starts with ThreadPoolExecutor(max_workers=args.threads) as executor: futures = [] + site: Site = None for line_num, line in dl_dict.items(): # keep track of batch file line numbers for debugging/logging purposes logger.info("Processing '%s' at line %i", line, line_num) try: file = prep_download(line, args) + # if there's already a Site object matching the desired host, reuse it + # to reduce the number of API calls made per file + if not site or site.host != file.site: + logger.debug("Made a new site connection") + site = connect_to_site(file.site, args) + else: + logger.debug("Reused an existing site connection") + file.image = query_api(file.name, site) except ParseError as e: logger.warning("%s (line %i)", str(e), line_num) errors += 1 diff --git a/tests/test_dl.py b/tests/test_dl.py index cb893b8..b0822cf 100644 --- a/tests/test_dl.py +++ b/tests/test_dl.py @@ -34,20 +34,9 @@ from wikiget.wikiget import parse_args class TestPrepDownload: """Define tests related to wikiget.dl.prep_download.""" - @patch("wikiget.dl.query_api") - @patch("wikiget.dl.connect_to_site") - def test_prep_download( - self, mock_connect_to_site: MagicMock, mock_query_api: MagicMock - ) -> None: + def test_prep_download(self) -> None: """The prep_download function should create the expected file object.""" - mock_site = Mock() - mock_image = Mock() - - mock_connect_to_site.return_value = mock_site - mock_query_api.return_value = mock_image - expected_file = File(name="Example.jpg") - expected_file.image = mock_image args = parse_args(["File:Example.jpg"]) file = prep_download(args.FILE, args) @@ -104,7 +93,8 @@ class TestProcessDownload: mock_prep_download.return_value = File("Example.jpg") args = parse_args(["File:Example.jpg"]) - exit_code = process_download(args) + with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"): + exit_code = process_download(args) assert exit_code == 0 @@ -118,7 +108,8 @@ class TestProcessDownload: mock_prep_download.return_value = File("Example.jpg") args = parse_args(["File:Example.jpg"]) - exit_code = process_download(args) + with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"): + exit_code = process_download(args) assert exit_code == 1 @@ -168,12 +159,10 @@ class TestBatchDownload: """Define tests related to wikiget.dl.batch_download.""" @patch("wikiget.dl.download") - @patch("wikiget.dl.prep_download") @patch("wikiget.dl.read_batch_file") def test_batch_download( self, mock_read_batch_file: MagicMock, - mock_prep_download: MagicMock, mock_download: MagicMock, caplog: pytest.LogCaptureFixture, ) -> None: @@ -189,16 +178,58 @@ class TestBatchDownload: mock_download.return_value = 0 args = parse_args(["-a", "batch.txt"]) - errors = batch_download(args) + with patch("wikiget.dl.query_api"), patch("wikiget.dl.connect_to_site"), patch( + "wikiget.dl.prep_download" + ): + errors = batch_download(args) assert mock_read_batch_file.called - assert mock_prep_download.called assert mock_download.called assert caplog.record_tuples == [ ("wikiget.dl", logging.INFO, "Processing 'File:Example.jpg' at line 1") ] assert errors == 0 + @patch("wikiget.dl.connect_to_site") + @patch("wikiget.dl.prep_download") + @patch("wikiget.dl.read_batch_file") + def test_batch_download_reuse_site( + self, + mock_read_batch_file: MagicMock, + mock_prep_download: MagicMock, + mock_connect_to_site: MagicMock, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Test that an existing site object is reused.""" + caplog.set_level(logging.DEBUG) + + mock_site = MagicMock() + mock_site.host = "commons.wikimedia.org" + mock_read_batch_file.return_value = { + 1: "File:Example.jpg", + 2: "File:Foobar.jpg", + } + mock_prep_download.return_value = File("Example.jpg") + mock_connect_to_site.return_value = mock_site + + args = parse_args(["-a", "batch.txt"]) + with patch("wikiget.dl.download"), patch("wikiget.dl.query_api"): + _ = batch_download(args) + + assert mock_read_batch_file.called + assert mock_prep_download.called + assert mock_connect_to_site.called + assert caplog.record_tuples[1] == ( + "wikiget.dl", + logging.DEBUG, + "Made a new site connection", + ) + assert caplog.record_tuples[3] == ( + "wikiget.dl", + logging.DEBUG, + "Reused an existing site connection", + ) + @patch("wikiget.dl.read_batch_file") def test_batch_download_os_error( self, mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture |
