aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/wikiget/dl.py16
-rw-r--r--src/wikiget/wikiget.py3
-rw-r--r--tests/test_dl.py18
3 files changed, 31 insertions, 6 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py
index f210182..2fd7388 100644
--- a/src/wikiget/dl.py
+++ b/src/wikiget/dl.py
@@ -38,6 +38,12 @@ logger = logging.getLogger(__name__)
def prep_download(dl: str, args: Namespace) -> File:
file = get_dest(dl, args)
+
+ # check if the destination file already exists; don't overwrite unless the user says
+ if os.path.isfile(file.dest) and not args.force:
+ msg = f"[{file.dest}] File already exists; skipping download (use -f to force)"
+ raise FileExistsError(msg)
+
site = connect_to_site(file.site, args)
file.image = query_api(file.name, site)
return file
@@ -65,6 +71,10 @@ def batch_download(args: Namespace) -> int:
logger.warning(f"{e} (line {line_num})")
errors += 1
continue
+ except FileExistsError as e:
+ logger.warning(e)
+ errors += 1
+ continue
except (ConnectionError, HTTPError, InvalidResponse, LoginError, APIError):
logger.warning(
f"Unable to download '{line}' (line {line_num}) due to an error"
@@ -101,11 +111,7 @@ def download(f: File, args: Namespace) -> int:
adapter.info(filename_log)
adapter.info(f"{file_url}")
- if os.path.isfile(dest) and not args.force:
- # TODO: check for this before the download process starts
- adapter.warning("File already exists; skipping download (use -f to force)")
- errors += 1
- elif args.dry_run:
+ if args.dry_run:
adapter.warning("Dry run; download skipped")
else:
try:
diff --git a/src/wikiget/wikiget.py b/src/wikiget/wikiget.py
index 0b08f68..f42da35 100644
--- a/src/wikiget/wikiget.py
+++ b/src/wikiget/wikiget.py
@@ -151,8 +151,11 @@ def main() -> None:
except ParseError as e:
logger.error(e)
sys.exit(1)
+ except FileExistsError:
+ sys.exit(1)
except (ConnectionError, HTTPError, InvalidResponse, LoginError, APIError):
sys.exit(1)
+
errors = download(file, args)
if errors:
sys.exit(1) # completed with errors
diff --git a/tests/test_dl.py b/tests/test_dl.py
index 5c962f9..69ff2bb 100644
--- a/tests/test_dl.py
+++ b/tests/test_dl.py
@@ -15,6 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with Wikiget. If not, see <https://www.gnu.org/licenses/>.
+from pathlib import Path
+
import pytest
from wikiget.dl import prep_download
@@ -22,9 +24,23 @@ from wikiget.wikiget import construct_parser
# TODO: don't hit the actual API when doing tests
-@pytest.mark.skip(reason="skip tests that query a live API")
class TestPrepDownload:
+ @pytest.mark.skip(reason="skip tests that query a live API")
def test_prep_download(self) -> None:
+ """
+ The prep_download function should create a file object.
+ """
args = construct_parser().parse_args(["File:Example.jpg"])
file = prep_download(args.FILE, args)
assert file is not None
+
+ def test_prep_download_with_existing_file(self, tmp_path: Path) -> None:
+ """
+ Attempting to download a file with the same destination name as an existing file
+ should raise a FileExistsError.
+ """
+ tmp_file = tmp_path / "File:Example.jpg"
+ tmp_file.write_text("nothing")
+ args = construct_parser().parse_args(["File:Example.jpg", "-o", str(tmp_file)])
+ with pytest.raises(FileExistsError):
+ _ = prep_download(args.FILE, args)