path: root/src/wikiget/dl.py
# wikiget - CLI tool for downloading files from Wikimedia sites
# Copyright (C) 2018-2023 Cody Logan and contributors
# SPDX-License-Identifier: GPL-3.0-or-later
#
# Wikiget is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Wikiget is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Wikiget. If not, see <https://www.gnu.org/licenses/>.

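"""Core download functionality for wikiget: connect to a MediaWiki site, look
up file information through its API, and download files singly or in batches.
"""
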
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor

from mwclient import APIError, InvalidResponse, LoginError, Site
from requests import ConnectionError, HTTPError
from tqdm import tqdm

import wikiget
from wikiget.exceptions import ParseError
from wikiget.file import File
from wikiget.parse import get_dest
from wikiget.validations import verify_hash


def query_api(filename, site_name, args):
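    """Connect to the named site and return the mwclient Image object for the
    requested file.

    Connection, authentication, and API errors are logged and re-raised so the
    caller can decide how to handle them.
    """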
    # connect to site and identify ourselves
    logging.info(f"Connecting to {site_name}")
    try:
        site = Site(site_name, path=args.path, clients_useragent=wikiget.USER_AGENT)
        if args.username and args.password:
            site.login(args.username, args.password)
    except ConnectionError as e:
        # usually this means there is no such site, or there's no network connection,
        # though it could be a certificate problem
        logging.error("Could not connect to specified site")
        logging.debug(e)
        raise
    except HTTPError as e:
        # most likely a 403 forbidden or 404 not found error for api.php
        logging.error(
            "Could not find the specified wiki's api.php. Check the value of --path."
        )
        logging.debug(e)
        raise
    except (InvalidResponse, LoginError) as e:
        # InvalidResponse: site exists, but we couldn't communicate with the API
        # endpoint for some reason other than an HTTP error.
        # LoginError: missing or invalid credentials
        logging.error(e)
        raise

    # get info about the target file
    try:
        image = site.images[filename]
    except APIError as e:
        # an API error at this point likely means access is denied, which could happen
        # with a private wiki
        logging.error(
            "Access denied. Try providing credentials with --username and --password."
        )
        for i in e.args:
            logging.debug(i)
        raise

    return image


def prep_download(dl, args):
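    """Parse a download target into a File object and attach the remote image
    information retrieved from the site's API.
    """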
    file = get_dest(dl, args)
    file.image = query_api(file.name, file.site, args)
    return file


def batch_download(args):
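    """Download every file listed in the batch file given by args.FILE.

    Blank lines and lines starting with "#" are skipped. Entries that fail to
    parse or to reach the API are logged and counted; the rest are downloaded
    in a thread pool of args.threads workers. Returns the number of errors
    encountered.
    """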
    input_file = args.FILE
    dl_list = {}
    errors = 0

    logging.info(f"Using batch file '{input_file}'.")

    try:
        fd = open(input_file)
    except OSError as e:
        logging.error("File could not be read. The following error was encountered:")
        logging.error(e)
        sys.exit(1)
    else:
        with fd:
            # process the batch file line by line, collecting entries to download
            for line_num, line in enumerate(fd, start=1):
                line_s = line.strip()
                # ignore blank lines and lines starting with "#" (for comments)
                if line_s and not line_s.startswith("#"):
                    dl_list[line_num] = line_s

    # TODO: validate file contents before download process starts
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = []
        for line_num, line in dl_list.items():
            # keep track of batch file line numbers for debugging/logging purposes
            logging.info(f"Processing '{line}' at line {line_num}")
            try:
                file = prep_download(line, args)
            except ParseError as e:
                logging.warning(f"{e} (line {line_num})")
                errors += 1
                continue
            except (ConnectionError, HTTPError, InvalidResponse, LoginError, APIError):
                logging.warning(
                    f"Unable to download '{line}' (line {line_num}) due to an error"
                )
                errors += 1
                continue
            future = executor.submit(download, file, args)
            futures.append(future)
        # wait for downloads to finish
        for future in futures:
            errors += future.result()
    return errors


def download(f, args):
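    """Download a single file described by the File object f.

    f.image is the mwclient Image returned by query_api; its imageinfo
    provides the URL, size, and SHA1 used to fetch and verify the file.
    Returns the number of errors encountered (0 on success).
    """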
    file = f.image
    filename = f.name
    dest = f.dest
    site = file.site

    errors = 0

    if file.exists:
        # file exists either locally or at a common repository, like Wikimedia Commons
        file_url = file.imageinfo["url"]
        file_size = file.imageinfo["size"]
        file_sha1 = file.imageinfo["sha1"]

        filename_log = f"Downloading '{filename}' ({file_size} bytes) from {site.host}"
        if args.output:
            filename_log += f" to '{dest}'"
        logging.info(filename_log)
        logging.info(f"{file_url}")

        if os.path.isfile(dest) and not args.force:
            logging.warning(
                f"File '{dest}' already exists, skipping download (use -f to force)"
            )
            errors += 1
        else:
            try:
                fd = open(dest, "wb")
            except OSError as e:
                logging.error(
                    "File could not be written. The following error was encountered:"
                )
                logging.error(e)
                errors += 1
            else:
                # download the file in chunks, showing a progress bar; keep the
                # bar on screen afterwards only when running verbosely
                leave_bars = args.verbose >= wikiget.STD_VERBOSE
                with tqdm(
                    leave=leave_bars,
                    total=file_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=wikiget.CHUNKSIZE,
                ) as progress_bar:
                    with fd:
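                        # reuse the site's HTTP session (site.connection) to
                        # stream the file in CHUNKSIZE chunks, updating the
                        # progress bar as each chunk is written to disk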
                        res = site.connection.get(file_url, stream=True)
                        progress_bar.set_postfix(file=dest, refresh=False)
                        for chunk in res.iter_content(wikiget.CHUNKSIZE):
                            fd.write(chunk)
                            progress_bar.update(len(chunk))

                # verify the integrity of the completed download and log details
                dl_sha1 = verify_hash(dest)

                logging.info(f"Remote file SHA1 is {file_sha1}")
                logging.info(f"Local file SHA1 is {dl_sha1}")
                if dl_sha1 == file_sha1:
                    logging.info("Hashes match!")
                    # at this point, we've successfully downloaded the file
                    success_log = f"'{filename}' downloaded"
                    if args.output:
                        success_log += f" to '{dest}'"
                    logging.info(success_log)
                else:
                    logging.error("Hash mismatch! Downloaded file may be corrupt.")
                    errors += 1

    else:
        # no file information returned
        logging.error(f"Target '{filename}' does not appear to be a valid file")
        errors += 1

    return errors