about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/wikiget/dl.py 19
-rw-r--r-- tests/test_dl.py 67
2 files changed, 61 insertions, 25 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py
index 722bf62..9b53d66 100644
--- a/src/wikiget/dl.py
+++ b/src/wikiget/dl.py
@@ -20,7 +20,7 @@ import sys
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor
-from mwclient import APIError, InvalidResponse, LoginError
+from mwclient import APIError, InvalidResponse, LoginError, Site
from requests import ConnectionError, HTTPError
from tqdm import tqdm
@@ -38,10 +38,6 @@ logger = logging.getLogger(__name__)
def prep_download(dl: str, args: Namespace) -> File:
"""Prepare to download a file by parsing the filename or URL and CLI arguments.
- First, the target is parsed for a valid name, destination, and site. If there are no
- problems creating a File with this information, we connect to the site hosting it
- and fetch the relevant Image object, which is added as an attribute to the File.
-
:param dl: a string representing the file or URL to download
:type dl: str
:param args: command-line arguments and their values
@@ -57,8 +53,6 @@ def prep_download(dl: str, args: Namespace) -> File:
msg = f"[{file.dest}] File already exists; skipping download (use -f to force)"
raise FileExistsError(msg)
- site = connect_to_site(file.site, args)
- file.image = query_api(file.name, site)
return file
@@ -96,6 +90,8 @@ def process_download(args: Namespace) -> int:
# single download mode
try:
file = prep_download(args.FILE, args)
+ site = connect_to_site(file.site, args)
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.error(e)
exit_code = 1
@@ -135,11 +131,20 @@ def batch_download(args: Namespace) -> int:
# TODO: validate file contents before download process starts
with ThreadPoolExecutor(max_workers=args.threads) as executor:
futures = []
+ site: Site = None
for line_num, line in dl_dict.items():
# keep track of batch file line numbers for debugging/logging purposes
logger.info("Processing '%s' at line %i", line, line_num)
try:
file = prep_download(line, args)
+ # if there's already a Site object matching the desired host, reuse it
+ # to reduce the number of API calls made per file
+ if not site or site.host != file.site:
+ logger.debug("Made a new site connection")
+ site = connect_to_site(file.site, args)
+ else:
+ logger.debug("Reused an existing site connection")
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.warning("%s (line %i)", str(e), line_num)
errors += 1
diff --git a/tests/test_dl.py b/tests/test_dl.py
index e062116..10f5fd5 100644
--- a/tests/test_dl.py
+++ b/tests/test_dl.py
@@ -34,20 +34,9 @@ from wikiget.wikiget import parse_args
class TestPrepDownload:
"""Define tests related to wikiget.dl.prep_download."""
- @patch("wikiget.dl.query_api")
- @patch("wikiget.dl.connect_to_site")
- def test_prep_download(
- self, mock_connect_to_site: MagicMock, mock_query_api: MagicMock
- ) -> None:
+ def test_prep_download(self) -> None:
"""The prep_download function should create the expected file object."""
- mock_site = Mock()
- mock_image = Mock()
-
- mock_connect_to_site.return_value = mock_site
- mock_query_api.return_value = mock_image
-
expected_file = File(name="Example.jpg")
- expected_file.image = mock_image
args = parse_args(["File:Example.jpg"])
file = prep_download(args.FILE, args)
@@ -104,7 +93,8 @@ class TestProcessDownload:
mock_prep_download.return_value = File("Example.jpg")
args = parse_args(["File:Example.jpg"])
- exit_code = process_download(args)
+ with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"):
+ exit_code = process_download(args)
assert exit_code == 0
@@ -118,7 +108,8 @@ class TestProcessDownload:
mock_prep_download.return_value = File("Example.jpg")
args = parse_args(["File:Example.jpg"])
- exit_code = process_download(args)
+ with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"):
+ exit_code = process_download(args)
assert exit_code == 1
@@ -168,12 +159,10 @@ class TestBatchDownload:
"""Define tests related to wikiget.dl.batch_download."""
@patch("wikiget.dl.download")
- @patch("wikiget.dl.prep_download")
@patch("wikiget.dl.read_batch_file")
def test_batch_download(
self,
mock_read_batch_file: MagicMock,
- mock_prep_download: MagicMock,
mock_download: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
@@ -189,16 +178,58 @@ class TestBatchDownload:
mock_download.return_value = 0
args = parse_args(["-a", "batch.txt"])
- errors = batch_download(args)
+ with patch("wikiget.dl.query_api"), patch("wikiget.dl.connect_to_site"), patch(
+ "wikiget.dl.prep_download"
+ ):
+ errors = batch_download(args)
assert mock_read_batch_file.called
- assert mock_prep_download.called
assert mock_download.called
assert caplog.record_tuples == [
("wikiget.dl", logging.INFO, "Processing 'File:Example.jpg' at line 1")
]
assert errors == 0
+ @patch("wikiget.dl.connect_to_site")
+ @patch("wikiget.dl.prep_download")
+ @patch("wikiget.dl.read_batch_file")
+ def test_batch_download_reuse_site(
+ self,
+ mock_read_batch_file: MagicMock,
+ mock_prep_download: MagicMock,
+ mock_connect_to_site: MagicMock,
+ caplog: pytest.LogCaptureFixture,
+ ) -> None:
+ """Test that an existing site object is reused."""
+ caplog.set_level(logging.DEBUG)
+
+ mock_site = MagicMock()
+ mock_site.host = "commons.wikimedia.org"
+ mock_read_batch_file.return_value = {
+ 1: "File:Example.jpg",
+ 2: "File:Foobar.jpg",
+ }
+ mock_prep_download.return_value = File("Example.jpg")
+ mock_connect_to_site.return_value = mock_site
+
+ args = parse_args(["-a", "batch.txt"])
+ with patch("wikiget.dl.download"), patch("wikiget.dl.query_api"):
+ _ = batch_download(args)
+
+ assert mock_read_batch_file.called
+ assert mock_prep_download.called
+ assert mock_connect_to_site.called
+ assert caplog.record_tuples[1] == (
+ "wikiget.dl",
+ logging.DEBUG,
+ "Made a new site connection",
+ )
+ assert caplog.record_tuples[3] == (
+ "wikiget.dl",
+ logging.DEBUG,
+ "Reused an existing site connection",
+ )
+
@patch("wikiget.dl.read_batch_file")
def test_batch_download_os_error(
self, mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture