aboutsummaryrefslogtreecommitdiff
path: root/src/wikiget/dl.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/wikiget/dl.py')
-rw-r--r--src/wikiget/dl.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py
index 60405d0..c9b2ed5 100644
--- a/src/wikiget/dl.py
+++ b/src/wikiget/dl.py
@@ -20,7 +20,7 @@ import sys
from argparse import Namespace
from concurrent.futures import ThreadPoolExecutor
-from mwclient import APIError, InvalidResponse, LoginError
+from mwclient import APIError, InvalidResponse, LoginError, Site
from requests import ConnectionError, HTTPError
from tqdm import tqdm
@@ -38,10 +38,6 @@ logger = logging.getLogger(__name__)
def prep_download(dl: str, args: Namespace) -> File:
"""Prepare to download a file by parsing the filename or URL and CLI arguments.
- First, the target is parsed for a valid name, destination, and site. If there are no
- problems creating a File with this information, we connect to the site hosting it
- and fetch the relevant Image object, which is added as an attribute to the File.
-
:param dl: a string representing the file or URL to download
:type dl: str
:param args: command-line arguments and their values
@@ -57,8 +53,6 @@ def prep_download(dl: str, args: Namespace) -> File:
msg = f"[{file.dest}] File already exists; skipping download (use -f to force)"
raise FileExistsError(msg)
- site = connect_to_site(file.site, args)
- file.image = query_api(file.name, site)
return file
@@ -96,6 +90,8 @@ def process_download(args: Namespace) -> int:
# single download mode
try:
file = prep_download(args.FILE, args)
+ site = connect_to_site(file.site, args)
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.error(e)
exit_code = 1
@@ -135,11 +131,20 @@ def batch_download(args: Namespace) -> int:
# TODO: validate file contents before download process starts
with ThreadPoolExecutor(max_workers=args.threads) as executor:
futures = []
+ site: Site = None
for line_num, line in dl_dict.items():
# keep track of batch file line numbers for debugging/logging purposes
logger.info("Processing '%s' at line %i", line, line_num)
try:
file = prep_download(line, args)
+ # if there's already a Site object matching the desired host, reuse it
+ # to reduce the number of API calls made per file
+ if not site or site.host != file.site:
+ logger.debug("Made a new site connection")
+ site = connect_to_site(file.site, args)
+ else:
+ logger.debug("Reused an existing site connection")
+ file.image = query_api(file.name, site)
except ParseError as e:
logger.warning("%s (line %i)", str(e), line_num)
errors += 1