diff options
| author | Cody Logan <cody@lokken.dev> | 2023-11-30 10:11:21 -0800 |
|---|---|---|
| committer | Cody Logan <cody@lokken.dev> | 2023-11-30 10:35:48 -0800 |
| commit | 26b2bfea7434aeb9d3687397341e0e7ad3f4edfc (patch) | |
| tree | 3eceae1d348454734bfd34f38728d33ce6f900e8 | |
| parent | 61b3733efd7b28bc2d3601aa9609a1119630c9ab (diff) | |
| download | wikiget-26b2bfea7434aeb9d3687397341e0e7ad3f4edfc.tar.gz wikiget-26b2bfea7434aeb9d3687397341e0e7ad3f4edfc.zip | |
Have functions return an exit code instead of calling sys.exit
| -rw-r--r-- | src/wikiget/dl.py | 3 | ||||
| -rw-r--r-- | src/wikiget/wikiget.py | 5 | ||||
| -rw-r--r-- | tests/test_dl.py | 4 | ||||
| -rw-r--r-- | tests/test_wikiget_cli.py | 20 |
4 files changed, 14 insertions, 18 deletions
diff --git a/src/wikiget/dl.py b/src/wikiget/dl.py index 85c6685..5a6ef0b 100644 --- a/src/wikiget/dl.py +++ b/src/wikiget/dl.py @@ -20,7 +20,6 @@ from __future__ import annotations import logging -import sys from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING @@ -134,7 +133,7 @@ def batch_download(args: Namespace) -> int: dl_dict = read_batch_file(args.FILE) except OSError as e: logger.error("File could not be read: %s", str(e)) - sys.exit(1) + return 1 with ThreadPoolExecutor(max_workers=args.threads) as executor: futures = [] diff --git a/src/wikiget/wikiget.py b/src/wikiget/wikiget.py index 152953c..0a6478e 100644 --- a/src/wikiget/wikiget.py +++ b/src/wikiget/wikiget.py @@ -129,7 +129,7 @@ def parse_args(argv: list[str]) -> argparse.Namespace: return parser.parse_args(argv) -def cli() -> None: +def cli() -> int: """Set up the command-line environment and start the download process.""" args = parse_args(sys.argv[1:]) configure_logging(verbosity=args.verbose, logfile=args.logfile, quiet=args.quiet) @@ -146,5 +146,4 @@ def cli() -> None: except KeyboardInterrupt: logger.critical("Interrupted by user") exit_code = 130 - finally: - sys.exit(exit_code) + return exit_code diff --git a/tests/test_dl.py b/tests/test_dl.py index 7117e50..63408eb 100644 --- a/tests/test_dl.py +++ b/tests/test_dl.py @@ -241,10 +241,10 @@ class TestBatchDownload: mock_read_batch_file.side_effect = OSError("error message") args = parse_args(["-a", "batch.txt"]) - with pytest.raises(SystemExit): - _ = batch_download(args) + errors = batch_download(args) assert mock_read_batch_file.called + assert errors == 1 assert caplog.record_tuples == [ ("wikiget.dl", logging.ERROR, "File could not be read: error message"), ] diff --git a/tests/test_wikiget_cli.py b/tests/test_wikiget_cli.py index 898ab8e..1ac3304 100644 --- a/tests/test_wikiget_cli.py +++ b/tests/test_wikiget_cli.py @@ -37,6 +37,7 @@ class TestWikigetCli: with monkeypatch.context() as m: m.setattr("sys.argv", ["wikiget"]) + # this SystemExit exception is raised by argparse with pytest.raises(SystemExit) as e: cli() @@ -51,10 +52,9 @@ class TestWikigetCli: mock_process_download.return_value = 0 m.setattr("sys.argv", ["wikiget", "File:Example.jpg"]) - with pytest.raises(SystemExit) as e: - cli() + code = cli() - assert e.value.code == 0 + assert code == 0 def test_cli_completed_with_problems( self, mock_process_download: MagicMock, monkeypatch: pytest.MonkeyPatch @@ -65,10 +65,9 @@ class TestWikigetCli: mock_process_download.return_value = 1 m.setattr("sys.argv", ["wikiget", "File:Example.jpg"]) - with pytest.raises(SystemExit) as e: - cli() + code = cli() - assert e.value.code == 1 + assert code == 1 def test_cli_logs( self, @@ -86,9 +85,9 @@ class TestWikigetCli: mock_process_download.return_value = 0 m.setattr("sys.argv", ["wikiget", "File:Example.jpg"]) - with pytest.raises(SystemExit): - cli() + code = cli() + assert code == 0 assert caplog.record_tuples == [ ( "wikiget.wikiget", @@ -116,10 +115,9 @@ class TestWikigetCli: mock_process_download.side_effect = KeyboardInterrupt m.setattr("sys.argv", ["wikiget", "File:Example.jpg"]) - with pytest.raises(SystemExit) as e: - cli() + code = cli() - assert e.value.code == 130 + assert code == 130 # ignore the first two messages, since they're tested elsewhere assert caplog.record_tuples[2] == ( "wikiget.wikiget", |
