aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCody Logan <cody@lokken.dev>2023-11-16 10:23:49 -0800
committerCody Logan <cody@lokken.dev>2023-11-16 10:23:49 -0800
commit2f3074e1b2a62cbd5e32778abc0ff82027c1ce3b (patch)
treec8e7fc5662cf1a34a5b2ecbd8b2bbbbaa630b19c
parentce58a03caa6f4d9e3cb01898b4b73716031b24dd (diff)
downloadwikiget-2f3074e1b2a62cbd5e32778abc0ff82027c1ce3b.tar.gz
wikiget-2f3074e1b2a62cbd5e32778abc0ff82027c1ce3b.zip
Reuse existing Site object when possible in batch downloads
Previously, every file downloaded in a batch would create a new Site object. Now, the Site object created by the first file will be reused by subsequent files if it matches the file's requested host, which will significantly speed up the download process, assuming all files are from the same site. This is a quick and dirty fix which could be improved to better handle situations where there are a mix of files from different sites.
-rw-r--r--src/wikiget/dl.py19
-rw-r--r--tests/test_dl.py67
2 files changed, 61 insertions, 25 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py
index 60405d0..c9b2ed5 100644
--- a/src/wikiget/dl.py
+++ b/src/wikiget/dl.py
@@ -20,7 +20,7 @@ import sys
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor
-from mwclient import APIError, InvalidResponse, LoginError
+from mwclient import APIError, InvalidResponse, LoginError, Site
from requests import ConnectionError, HTTPError
from tqdm import tqdm
@@ -38,10 +38,6 @@ logger = logging.getLogger(__name__)
def prep_download(dl: str, args: Namespace) -> File:
"""Prepare to download a file by parsing the filename or URL and CLI arguments.
- First, the target is parsed for a valid name, destination, and site. If there are no
- problems creating a File with this information, we connect to the site hosting it
- and fetch the relevant Image object, which is added as an attribute to the File.
-
:param dl: a string representing the file or URL to download
:type dl: str
:param args: command-line arguments and their values
@@ -57,8 +53,6 @@ def prep_download(dl: str, args: Namespace) -> File:
msg = f"[{file.dest}] File already exists; skipping download (use -f to force)"
raise FileExistsError(msg)
- site = connect_to_site(file.site, args)
- file.image = query_api(file.name, site)
return file
@@ -96,6 +90,8 @@ def process_download(args: Namespace) -> int:
# single download mode
try:
file = prep_download(args.FILE, args)
+ site = connect_to_site(file.site, args)
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.error(e)
exit_code = 1
@@ -135,11 +131,20 @@ def batch_download(args: Namespace) -> int:
# TODO: validate file contents before download process starts
with ThreadPoolExecutor(max_workers=args.threads) as executor:
futures = []
+ site: Site = None
for line_num, line in dl_dict.items():
# keep track of batch file line numbers for debugging/logging purposes
logger.info("Processing '%s' at line %i", line, line_num)
try:
file = prep_download(line, args)
+ # if there's already a Site object matching the desired host, reuse it
+ # to reduce the number of API calls made per file
+ if not site or site.host != file.site:
+ logger.debug("Made a new site connection")
+ site = connect_to_site(file.site, args)
+ else:
+ logger.debug("Reused an existing site connection")
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.warning("%s (line %i)", str(e), line_num)
errors += 1
diff --git a/tests/test_dl.py b/tests/test_dl.py
index cb893b8..b0822cf 100644
--- a/tests/test_dl.py
+++ b/tests/test_dl.py
@@ -34,20 +34,9 @@ from wikiget.wikiget import parse_args
class TestPrepDownload:
"""Define tests related to wikiget.dl.prep_download."""
- @patch("wikiget.dl.query_api")
- @patch("wikiget.dl.connect_to_site")
- def test_prep_download(
- self, mock_connect_to_site: MagicMock, mock_query_api: MagicMock
- ) -> None:
+ def test_prep_download(self) -> None:
"""The prep_download function should create the expected file object."""
- mock_site = Mock()
- mock_image = Mock()
-
- mock_connect_to_site.return_value = mock_site
- mock_query_api.return_value = mock_image
-
expected_file = File(name="Example.jpg")
- expected_file.image = mock_image
args = parse_args(["File:Example.jpg"])
file = prep_download(args.FILE, args)
@@ -104,7 +93,8 @@ class TestProcessDownload:
mock_prep_download.return_value = File("Example.jpg")
args = parse_args(["File:Example.jpg"])
- exit_code = process_download(args)
+ with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"):
+ exit_code = process_download(args)
assert exit_code == 0
@@ -118,7 +108,8 @@ class TestProcessDownload:
mock_prep_download.return_value = File("Example.jpg")
args = parse_args(["File:Example.jpg"])
- exit_code = process_download(args)
+ with patch("wikiget.dl.connect_to_site"), patch("wikiget.dl.query_api"):
+ exit_code = process_download(args)
assert exit_code == 1
@@ -168,12 +159,10 @@ class TestBatchDownload:
"""Define tests related to wikiget.dl.batch_download."""
@patch("wikiget.dl.download")
- @patch("wikiget.dl.prep_download")
@patch("wikiget.dl.read_batch_file")
def test_batch_download(
self,
mock_read_batch_file: MagicMock,
- mock_prep_download: MagicMock,
mock_download: MagicMock,
caplog: pytest.LogCaptureFixture,
) -> None:
@@ -189,16 +178,58 @@ class TestBatchDownload:
mock_download.return_value = 0
args = parse_args(["-a", "batch.txt"])
- errors = batch_download(args)
+ with patch("wikiget.dl.query_api"), patch("wikiget.dl.connect_to_site"), patch(
+ "wikiget.dl.prep_download"
+ ):
+ errors = batch_download(args)
assert mock_read_batch_file.called
- assert mock_prep_download.called
assert mock_download.called
assert caplog.record_tuples == [
("wikiget.dl", logging.INFO, "Processing 'File:Example.jpg' at line 1")
]
assert errors == 0
+ @patch("wikiget.dl.connect_to_site")
+ @patch("wikiget.dl.prep_download")
+ @patch("wikiget.dl.read_batch_file")
+ def test_batch_download_reuse_site(
+ self,
+ mock_read_batch_file: MagicMock,
+ mock_prep_download: MagicMock,
+ mock_connect_to_site: MagicMock,
+ caplog: pytest.LogCaptureFixture,
+ ) -> None:
+ """Test that an existing site object is reused."""
+ caplog.set_level(logging.DEBUG)
+
+ mock_site = MagicMock()
+ mock_site.host = "commons.wikimedia.org"
+ mock_read_batch_file.return_value = {
+ 1: "File:Example.jpg",
+ 2: "File:Foobar.jpg",
+ }
+ mock_prep_download.return_value = File("Example.jpg")
+ mock_connect_to_site.return_value = mock_site
+
+ args = parse_args(["-a", "batch.txt"])
+ with patch("wikiget.dl.download"), patch("wikiget.dl.query_api"):
+ _ = batch_download(args)
+
+ assert mock_read_batch_file.called
+ assert mock_prep_download.called
+ assert mock_connect_to_site.called
+ assert caplog.record_tuples[1] == (
+ "wikiget.dl",
+ logging.DEBUG,
+ "Made a new site connection",
+ )
+ assert caplog.record_tuples[3] == (
+ "wikiget.dl",
+ logging.DEBUG,
+ "Reused an existing site connection",
+ )
+
@patch("wikiget.dl.read_batch_file")
def test_batch_download_os_error(
self, mock_read_batch_file: MagicMock, caplog: pytest.LogCaptureFixture