Chore: Update typing and baselines again (#12641)

a
This commit is contained in:
Trenton H
2026-04-28 09:28:05 -07:00
committed by GitHub
parent ff95512b9a
commit 14fe520319
42 changed files with 4698 additions and 6333 deletions
+301 -606
View File
File diff suppressed because it is too large Load Diff
+3779 -5267
View File
File diff suppressed because it is too large Load Diff
+2
View File
@@ -178,6 +178,8 @@ respect-gitignore = true
fix = true
show-fixes = true
output-format = "grouped"
[tool.ruff.format]
line-ending = "lf"
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
extend-select = [
@@ -6,6 +6,7 @@ import tempfile
from itertools import islice
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from allauth.mfa.models import Authenticator
from allauth.socialaccount.models import SocialAccount
@@ -68,7 +69,7 @@ from paperless_mail.models import MailRule
def serialize_queryset_batched(
queryset: "QuerySet",
queryset: "QuerySet[Any]",
*,
batch_size: int = 500,
) -> "Generator[list[dict], None, None]":
@@ -364,7 +365,7 @@ class Command(CryptMixin, PaperlessCommand):
# 2. Create manifest, containing all correspondents, types, tags, storage paths
# note, documents and ui_settings
manifest_key_to_object_query: dict[str, QuerySet] = {
manifest_key_to_object_query: dict[str, QuerySet[Any]] = {
"correspondents": Correspondent.objects.all(),
"tags": Tag.objects.all(),
"document_types": DocumentType.objects.all(),
+1 -1
View File
@@ -261,7 +261,7 @@ def get_objects_for_user_owner_aware(
Model: Any,
*,
include_deleted: bool = False,
) -> QuerySet:
) -> QuerySet[Any]:
"""
Returns objects the user owns, are unowned, or has explicit perms.
When include_deleted is True, soft-deleted items are also included.
+17 -17
View File
@@ -214,7 +214,7 @@ class SetPermissionsMixin:
set_permissions_for_object(permissions, object)
class SerializerWithPerms(serializers.Serializer):
class SerializerWithPerms(serializers.Serializer[dict[str, Any]]):
def __init__(self, *args, **kwargs) -> None:
self.user = kwargs.pop("user", None)
self.full_perms = kwargs.pop("full_perms", False)
@@ -961,20 +961,12 @@ def _get_viewable_duplicates(
return duplicates.filter(id__in=allowed)
class DuplicateDocumentSummarySerializer(serializers.Serializer):
class DuplicateDocumentSummarySerializer(serializers.Serializer[dict[str, Any]]):
id = serializers.IntegerField()
title = serializers.CharField()
deleted_at = serializers.DateTimeField(allow_null=True)
class DocumentVersionInfoSerializer(serializers.Serializer):
id = serializers.IntegerField()
added = serializers.DateTimeField()
version_label = serializers.CharField(required=False, allow_null=True)
checksum = serializers.CharField(required=False, allow_null=True)
is_root = serializers.BooleanField()
class _DocumentVersionInfo(TypedDict):
id: int
added: datetime
@@ -983,6 +975,14 @@ class _DocumentVersionInfo(TypedDict):
is_root: bool
class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]):
id = serializers.IntegerField()
added = serializers.DateTimeField()
version_label = serializers.CharField(required=False, allow_null=True)
checksum = serializers.CharField(required=False, allow_null=True)
is_root = serializers.BooleanField()
@extend_schema_serializer(
deprecate_fields=["created_date"],
)
@@ -1532,7 +1532,7 @@ class SavedViewSerializer(OwnedObjectSerializer):
return saved_view
class DocumentListSerializer(serializers.Serializer):
class DocumentListSerializer(serializers.Serializer[dict[str, list[int]]]):
documents = serializers.ListField(
required=True,
label="Documents",
@@ -2085,7 +2085,7 @@ class BulkEditSerializer(
return attrs
class PostDocumentSerializer(serializers.Serializer):
class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
created = serializers.DateTimeField(
label="Created",
allow_null=True,
@@ -2262,7 +2262,7 @@ class PostDocumentSerializer(serializers.Serializer):
return created.date()
class DocumentVersionSerializer(serializers.Serializer):
class DocumentVersionSerializer(serializers.Serializer[dict[str, Any]]):
document = serializers.FileField(
label="Document",
write_only=True,
@@ -2278,7 +2278,7 @@ class DocumentVersionSerializer(serializers.Serializer):
validate_document = PostDocumentSerializer().validate_document
class DocumentVersionLabelSerializer(serializers.Serializer):
class DocumentVersionLabelSerializer(serializers.Serializer[dict[str, str | None]]):
version_label = serializers.CharField(
label="Version label",
required=True,
@@ -2484,7 +2484,7 @@ class TaskSerializerV10(OwnedObjectSerializer):
read_only_fields = fields
class TaskSerializerV9(serializers.ModelSerializer):
class TaskSerializerV9(serializers.ModelSerializer[PaperlessTask]):
"""Task serializer for API v9 backwards compatibility.
Maps old field names to the new model fields so existing clients continue
@@ -2609,7 +2609,7 @@ class TaskSerializerV9(serializers.ModelSerializer):
return list(qs.values("id", "title", "deleted_at"))
class TaskSummarySerializer(serializers.Serializer):
class TaskSummarySerializer(serializers.Serializer[dict[str, Any]]):
task_type = serializers.CharField()
total_count = serializers.IntegerField()
pending_count = serializers.IntegerField()
@@ -2622,7 +2622,7 @@ class TaskSummarySerializer(serializers.Serializer):
last_failure = serializers.DateTimeField(allow_null=True)
class RunTaskSerializer(serializers.Serializer):
class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
task_type = serializers.ChoiceField(
choices=PaperlessTask.TaskType.choices,
label="Task Type",
+6 -6
View File
@@ -16,7 +16,7 @@ from documents.models import StoragePath
from documents.models import Tag
class CorrespondentFactory(DjangoModelFactory):
class CorrespondentFactory(DjangoModelFactory[Correspondent]):
class Meta:
model = Correspondent
@@ -25,7 +25,7 @@ class CorrespondentFactory(DjangoModelFactory):
matching_algorithm = MatchingModel.MATCH_NONE
class DocumentTypeFactory(DjangoModelFactory):
class DocumentTypeFactory(DjangoModelFactory[DocumentType]):
class Meta:
model = DocumentType
@@ -34,7 +34,7 @@ class DocumentTypeFactory(DjangoModelFactory):
matching_algorithm = MatchingModel.MATCH_NONE
class TagFactory(DjangoModelFactory):
class TagFactory(DjangoModelFactory[Tag]):
class Meta:
model = Tag
@@ -44,7 +44,7 @@ class TagFactory(DjangoModelFactory):
is_inbox_tag = False
class StoragePathFactory(DjangoModelFactory):
class StoragePathFactory(DjangoModelFactory[StoragePath]):
class Meta:
model = StoragePath
@@ -56,7 +56,7 @@ class StoragePathFactory(DjangoModelFactory):
matching_algorithm = MatchingModel.MATCH_NONE
class DocumentFactory(DjangoModelFactory):
class DocumentFactory(DjangoModelFactory[Document]):
class Meta:
model = Document
@@ -68,7 +68,7 @@ class DocumentFactory(DjangoModelFactory):
storage_path = None
class PaperlessTaskFactory(DjangoModelFactory):
class PaperlessTaskFactory(DjangoModelFactory[PaperlessTask]):
class Meta:
model = PaperlessTask
@@ -4,6 +4,7 @@ from __future__ import annotations
import io
from typing import TYPE_CHECKING
from typing import Any
import pytest
from django.core.management import CommandError
@@ -122,7 +123,7 @@ def mock_queryset():
This verifies we use .count() instead of len() for querysets.
"""
class MockQuerySet(QuerySet):
class MockQuerySet(QuerySet[Any]):
def __init__(self, items: list):
self._items = items
self.count_called = False
@@ -147,7 +148,7 @@ def mock_queryset():
class TestProcessResult:
"""Tests for the ProcessResult dataclass."""
def test_success_result(self):
def test_success_result(self) -> None:
result = ProcessResult(item=1, result=2, error=None)
assert result.item == 1
@@ -155,7 +156,7 @@ class TestProcessResult:
assert result.error is None
assert result.success is True
def test_error_result(self):
def test_error_result(self) -> None:
error = ValueError("test error")
result = ProcessResult(item=1, result=None, error=error)
@@ -169,7 +170,7 @@ class TestProcessResult:
class TestPaperlessCommandArguments:
"""Tests for argument parsing behavior."""
def test_progress_bar_argument_added_by_default(self):
def test_progress_bar_argument_added_by_default(self) -> None:
command = SimpleCommand()
parser = command.create_parser("manage.py", "simple")
@@ -179,14 +180,14 @@ class TestPaperlessCommandArguments:
options = parser.parse_args([])
assert options.no_progress_bar is False
def test_progress_bar_argument_not_added_when_disabled(self):
def test_progress_bar_argument_not_added_when_disabled(self) -> None:
command = NoProgressBarCommand()
parser = command.create_parser("manage.py", "noprogress")
options = parser.parse_args([])
assert not hasattr(options, "no_progress_bar")
def test_processes_argument_added_when_multiprocessing_enabled(self):
def test_processes_argument_added_when_multiprocessing_enabled(self) -> None:
command = MultiprocessCommand()
parser = command.create_parser("manage.py", "multiprocess")
@@ -196,7 +197,7 @@ class TestPaperlessCommandArguments:
options = parser.parse_args([])
assert options.processes >= 1
def test_processes_argument_not_added_when_multiprocessing_disabled(self):
def test_processes_argument_not_added_when_multiprocessing_disabled(self) -> None:
command = SimpleCommand()
parser = command.create_parser("manage.py", "simple")
@@ -231,7 +232,7 @@ class TestPaperlessCommandExecute:
*,
no_progress_bar_flag: bool,
expected: bool,
):
) -> None:
command = SimpleCommand()
command.stdout = io.StringIO()
command.stderr = io.StringIO()
@@ -241,7 +242,10 @@ class TestPaperlessCommandExecute:
assert command.no_progress_bar is expected
def test_no_progress_bar_always_true_when_not_supported(self, base_options: dict):
def test_no_progress_bar_always_true_when_not_supported(
self,
base_options: dict,
) -> None:
command = NoProgressBarCommand()
command.stdout = io.StringIO()
command.stderr = io.StringIO()
@@ -262,7 +266,7 @@ class TestPaperlessCommandExecute:
base_options: dict,
processes: int,
expected: int,
):
) -> None:
command = MultiprocessCommand()
command.stdout = io.StringIO()
command.stderr = io.StringIO()
@@ -283,7 +287,7 @@ class TestPaperlessCommandExecute:
self,
base_options: dict,
invalid_count: int,
):
) -> None:
command = MultiprocessCommand()
command.stdout = io.StringIO()
command.stderr = io.StringIO()
@@ -293,7 +297,10 @@ class TestPaperlessCommandExecute:
with pytest.raises(CommandError, match="--processes must be at least 1"):
command.execute(**options)
def test_process_count_defaults_to_one_when_not_supported(self, base_options: dict):
def test_process_count_defaults_to_one_when_not_supported(
self,
base_options: dict,
) -> None:
command = SimpleCommand()
command.stdout = io.StringIO()
command.stderr = io.StringIO()
@@ -312,7 +319,7 @@ class TestGetIterableLength:
self,
simple_command: SimpleCommand,
mock_queryset,
):
) -> None:
"""Should call .count() on Django querysets rather than len()."""
queryset = mock_queryset([1, 2, 3, 4, 5])
@@ -321,13 +328,16 @@ class TestGetIterableLength:
assert result == 5
assert queryset.count_called is True
def test_uses_len_for_sized(self, simple_command: SimpleCommand):
def test_uses_len_for_sized(self, simple_command: SimpleCommand) -> None:
"""Should use len() for sequences and other Sized types."""
result = simple_command._get_iterable_length([1, 2, 3, 4])
assert result == 4
def test_returns_none_for_unsized_iterables(self, simple_command: SimpleCommand):
def test_returns_none_for_unsized_iterables(
self,
simple_command: SimpleCommand,
) -> None:
"""Should return None for generators and other iterables without len()."""
result = simple_command._get_iterable_length(x for x in [1, 2, 3])
@@ -338,7 +348,7 @@ class TestGetIterableLength:
class TestTrack:
"""Tests for the track() method."""
def test_with_progress_bar_disabled(self, simple_command: SimpleCommand):
def test_with_progress_bar_disabled(self, simple_command: SimpleCommand) -> None:
simple_command.no_progress_bar = True
items = ["a", "b", "c"]
@@ -346,7 +356,7 @@ class TestTrack:
assert result == items
def test_with_progress_bar_enabled(self, simple_command: SimpleCommand):
def test_with_progress_bar_enabled(self, simple_command: SimpleCommand) -> None:
simple_command.no_progress_bar = False
items = [1, 2, 3]
@@ -354,7 +364,7 @@ class TestTrack:
assert result == items
def test_with_explicit_total(self, simple_command: SimpleCommand):
def test_with_explicit_total(self, simple_command: SimpleCommand) -> None:
simple_command.no_progress_bar = False
def gen():
@@ -364,7 +374,7 @@ class TestTrack:
assert result == [1, 2, 3]
def test_with_generator_no_total(self, simple_command: SimpleCommand):
def test_with_generator_no_total(self, simple_command: SimpleCommand) -> None:
def gen():
yield from [1, 2, 3]
@@ -372,7 +382,7 @@ class TestTrack:
assert result == [1, 2, 3]
def test_empty_iterable(self, simple_command: SimpleCommand):
def test_empty_iterable(self, simple_command: SimpleCommand) -> None:
result = list(simple_command.track([]))
assert result == []
@@ -382,7 +392,7 @@ class TestTrack:
simple_command: SimpleCommand,
mock_queryset,
mocker: MockerFixture,
):
) -> None:
"""Verify track() uses .count() for querysets."""
simple_command.no_progress_bar = False
queryset = mock_queryset([1, 2, 3])
@@ -403,7 +413,7 @@ class TestProcessParallel:
def test_sequential_processing_single_process(
self,
multiprocess_command: MultiprocessCommand,
):
) -> None:
multiprocess_command.process_count = 1
items = [1, 2, 3, 4, 5]
@@ -418,7 +428,7 @@ class TestProcessParallel:
def test_sequential_processing_handles_errors(
self,
multiprocess_command: MultiprocessCommand,
):
) -> None:
multiprocess_command.process_count = 1
items = [1, 2, 0, 4] # 0 causes ZeroDivisionError
@@ -438,7 +448,7 @@ class TestProcessParallel:
self,
multiprocess_command: MultiprocessCommand,
mocker: MockerFixture,
):
) -> None:
multiprocess_command.process_count = 2
items = [1, 2, 3]
@@ -455,7 +465,7 @@ class TestProcessParallel:
self,
multiprocess_command: MultiprocessCommand,
mocker: MockerFixture,
):
) -> None:
multiprocess_command.process_count = 2
items = [1, 2, 0, 4]
@@ -467,7 +477,7 @@ class TestProcessParallel:
assert len(failures) == 1
assert failures[0].item == 0
def test_empty_items(self, multiprocess_command: MultiprocessCommand):
def test_empty_items(self, multiprocess_command: MultiprocessCommand) -> None:
results = list(multiprocess_command.process_parallel(_double_value, []))
assert results == []
@@ -475,7 +485,7 @@ class TestProcessParallel:
def test_result_contains_original_item(
self,
multiprocess_command: MultiprocessCommand,
):
) -> None:
items = [10, 20, 30]
results = list(multiprocess_command.process_parallel(_double_value, items))
@@ -488,7 +498,7 @@ class TestProcessParallel:
self,
multiprocess_command: MultiprocessCommand,
mocker: MockerFixture,
):
) -> None:
"""Verify single process uses sequential path (important for testing)."""
multiprocess_command.process_count = 1
@@ -504,7 +514,7 @@ class TestProcessParallel:
self,
multiprocess_command: MultiprocessCommand,
mocker: MockerFixture,
):
) -> None:
"""Verify multiple processes uses parallel path."""
multiprocess_command.process_count = 2
+49 -31
View File
@@ -16,7 +16,7 @@ pytestmark = [pytest.mark.search, pytest.mark.django_db]
class TestWriteBatch:
"""Test WriteBatch context manager functionality."""
def test_rolls_back_on_exception(self, backend: TantivyBackend):
def test_rolls_back_on_exception(self, backend: TantivyBackend) -> None:
"""Batch operations must rollback on exception to preserve index integrity."""
doc = Document.objects.create(
title="Rollback Target",
@@ -43,7 +43,7 @@ class TestSearch:
def test_text_mode_limits_default_search_to_title_and_content(
self,
backend: TantivyBackend,
):
) -> None:
"""Simple text mode must not match metadata-only fields."""
doc = Document.objects.create(
title="Invoice document",
@@ -71,7 +71,7 @@ class TestSearch:
def test_title_mode_limits_default_search_to_title_only(
self,
backend: TantivyBackend,
):
) -> None:
"""Title mode must not match content-only terms."""
doc = Document.objects.create(
title="Invoice document",
@@ -93,7 +93,7 @@ class TestSearch:
def test_text_mode_matches_partial_term_substrings(
self,
backend: TantivyBackend,
):
) -> None:
"""Simple text mode should support substring matching within tokens."""
doc = Document.objects.create(
title="Account access",
@@ -117,7 +117,7 @@ class TestSearch:
def test_text_mode_does_not_match_on_partial_term_overlap(
self,
backend: TantivyBackend,
):
) -> None:
"""Simple text mode should not match documents that merely share partial fragments."""
doc = Document.objects.create(
title="Adobe Acrobat PDF Files",
@@ -135,7 +135,7 @@ class TestSearch:
def test_text_mode_anchors_later_query_tokens_to_token_starts(
self,
backend: TantivyBackend,
):
) -> None:
"""Multi-token simple search should not match later tokens in the middle of a word."""
exact_doc = Document.objects.create(
title="Z-Berichte 6",
@@ -170,7 +170,7 @@ class TestSearch:
def test_text_mode_ignores_queries_without_searchable_tokens(
self,
backend: TantivyBackend,
):
) -> None:
"""Simple text mode should safely return no hits for symbol-only strings."""
doc = Document.objects.create(
title="Guide",
@@ -187,7 +187,7 @@ class TestSearch:
def test_title_mode_matches_partial_term_substrings(
self,
backend: TantivyBackend,
):
) -> None:
"""Title mode should support substring matching within title tokens."""
doc = Document.objects.create(
title="Password guide",
@@ -210,7 +210,7 @@ class TestSearch:
== 1
)
def test_sort_field_ascending(self, backend: TantivyBackend):
def test_sort_field_ascending(self, backend: TantivyBackend) -> None:
"""Searching with sort_reverse=False must return results in ascending ASN order."""
for asn in [30, 10, 20]:
doc = Document.objects.create(
@@ -231,7 +231,7 @@ class TestSearch:
asns = [Document.objects.get(pk=doc_id).archive_serial_number for doc_id in ids]
assert asns == [10, 20, 30]
def test_sort_field_descending(self, backend: TantivyBackend):
def test_sort_field_descending(self, backend: TantivyBackend) -> None:
"""Searching with sort_reverse=True must return results in descending ASN order."""
for asn in [30, 10, 20]:
doc = Document.objects.create(
@@ -256,7 +256,7 @@ class TestSearch:
class TestSearchIds:
"""Test lightweight ID-only search."""
def test_returns_matching_ids(self, backend: TantivyBackend):
def test_returns_matching_ids(self, backend: TantivyBackend) -> None:
"""search_ids must return IDs of all matching documents."""
docs = []
for i in range(5):
@@ -282,7 +282,7 @@ class TestSearchIds:
assert set(ids) == {d.pk for d in docs}
assert other.pk not in ids
def test_respects_permission_filter(self, backend: TantivyBackend):
def test_respects_permission_filter(self, backend: TantivyBackend) -> None:
"""search_ids must respect user permission filtering."""
owner = User.objects.create_user("ids_owner")
other = User.objects.create_user("ids_other")
@@ -303,7 +303,7 @@ class TestSearchIds:
backend.search_ids("secret", user=other, search_mode=SearchMode.QUERY) == []
)
def test_respects_fuzzy_threshold(self, backend: TantivyBackend, settings):
def test_respects_fuzzy_threshold(self, backend: TantivyBackend, settings) -> None:
"""search_ids must apply the same fuzzy threshold as search()."""
doc = Document.objects.create(
title="threshold test",
@@ -316,7 +316,7 @@ class TestSearchIds:
ids = backend.search_ids("unique", user=None, search_mode=SearchMode.QUERY)
assert ids == []
def test_returns_ids_for_text_mode(self, backend: TantivyBackend):
def test_returns_ids_for_text_mode(self, backend: TantivyBackend) -> None:
"""search_ids must work with TEXT search mode."""
doc = Document.objects.create(
title="text mode doc",
@@ -332,7 +332,7 @@ class TestSearchIds:
class TestRebuild:
"""Test index rebuilding functionality."""
def test_with_iter_wrapper_called(self, backend: TantivyBackend):
def test_with_iter_wrapper_called(self, backend: TantivyBackend) -> None:
"""Index rebuild must pass documents through iter_wrapper for progress tracking."""
seen = []
@@ -349,7 +349,7 @@ class TestRebuild:
class TestAutocomplete:
"""Test autocomplete functionality."""
def test_basic_functionality(self, backend: TantivyBackend):
def test_basic_functionality(self, backend: TantivyBackend) -> None:
"""Autocomplete must return words matching the given prefix."""
doc = Document.objects.create(
title="Invoice from Microsoft Corporation",
@@ -362,7 +362,10 @@ class TestAutocomplete:
results = backend.autocomplete("micro", limit=10)
assert "microsoft" in results
def test_results_ordered_by_document_frequency(self, backend: TantivyBackend):
def test_results_ordered_by_document_frequency(
self,
backend: TantivyBackend,
) -> None:
"""Autocomplete results must be ordered by document frequency to prioritize common terms."""
# "payment" appears in 3 docs; "payslip" in 1 — "pay" prefix should
# return "payment" before "payslip".
@@ -390,7 +393,10 @@ class TestAutocomplete:
class TestMoreLikeThis:
"""Test more like this functionality."""
def test_more_like_this_ids_excludes_original(self, backend: TantivyBackend):
def test_more_like_this_ids_excludes_original(
self,
backend: TantivyBackend,
) -> None:
"""more_like_this_ids must return IDs of similar documents, excluding the original."""
doc1 = Document.objects.create(
title="Important document",
@@ -421,11 +427,11 @@ class TestSingleton:
yield
reset_backend()
def test_returns_same_instance_on_repeated_calls(self, index_dir):
def test_returns_same_instance_on_repeated_calls(self, index_dir) -> None:
"""Singleton pattern: repeated calls to get_backend() must return the same instance."""
assert get_backend() is get_backend()
def test_reinitializes_when_index_dir_changes(self, tmp_path, settings):
def test_reinitializes_when_index_dir_changes(self, tmp_path, settings) -> None:
"""Backend singleton must reinitialize when INDEX_DIR setting changes for test isolation."""
settings.INDEX_DIR = tmp_path / "a"
(tmp_path / "a").mkdir()
@@ -438,7 +444,7 @@ class TestSingleton:
assert b1 is not b2
assert b2._path == tmp_path / "b"
def test_reset_forces_new_instance(self, index_dir):
def test_reset_forces_new_instance(self, index_dir) -> None:
"""reset_backend() must force creation of a new backend instance on next get_backend() call."""
b1 = get_backend()
reset_backend()
@@ -449,7 +455,7 @@ class TestSingleton:
class TestFieldHandling:
"""Test handling of various document fields."""
def test_none_values_handled_correctly(self, backend: TantivyBackend):
def test_none_values_handled_correctly(self, backend: TantivyBackend) -> None:
"""Document fields with None values must not cause indexing errors."""
doc = Document.objects.create(
title="Test Doc",
@@ -464,7 +470,10 @@ class TestFieldHandling:
assert len(backend.search_ids("test", user=None)) == 1
def test_custom_fields_include_name_and_value(self, backend: TantivyBackend):
def test_custom_fields_include_name_and_value(
self,
backend: TantivyBackend,
) -> None:
"""Custom fields must be indexed with both field name and value for structured queries."""
field = CustomField.objects.create(
name="Invoice Number",
@@ -486,7 +495,10 @@ class TestFieldHandling:
assert len(backend.search_ids("invoice", user=None)) == 1
def test_select_custom_field_indexes_label_not_id(self, backend: TantivyBackend):
def test_select_custom_field_indexes_label_not_id(
self,
backend: TantivyBackend,
) -> None:
"""SELECT custom fields must index the human-readable label, not the opaque option ID."""
field = CustomField.objects.create(
name="Category",
@@ -514,7 +526,7 @@ class TestFieldHandling:
assert len(backend.search_ids("custom_fields.value:invoice", user=None)) == 1
assert len(backend.search_ids("custom_fields.value:opt_abc", user=None)) == 0
def test_none_custom_field_value_not_indexed(self, backend: TantivyBackend):
def test_none_custom_field_value_not_indexed(self, backend: TantivyBackend) -> None:
"""Custom field instances with no value set must not produce an index entry."""
field = CustomField.objects.create(
name="Optional",
@@ -536,7 +548,7 @@ class TestFieldHandling:
assert len(backend.search_ids("custom_fields.value:none", user=None)) == 0
def test_notes_include_user_information(self, backend: TantivyBackend):
def test_notes_include_user_information(self, backend: TantivyBackend) -> None:
"""Notes must be indexed with user information when available for structured queries."""
user = User.objects.create_user("notewriter")
doc = Document.objects.create(
@@ -566,7 +578,7 @@ class TestHighlightHits:
def test_highlights_simple_text_mode_returns_html_string(
self,
backend: TantivyBackend,
):
) -> None:
"""Simple text search should still produce content highlights for exact-token hits."""
doc = Document.objects.create(
title="Highlight Test",
@@ -583,7 +595,10 @@ class TestHighlightHits:
assert "content" in highlights
assert "<b>" in highlights["content"]
def test_highlights_content_returns_html_string(self, backend: TantivyBackend):
def test_highlights_content_returns_html_string(
self,
backend: TantivyBackend,
) -> None:
"""highlight_hits must return HTML strings (from Snippet.to_html()), not Snippet objects."""
doc = Document.objects.create(
title="Highlight Test",
@@ -607,7 +622,10 @@ class TestHighlightHits:
f"Expected HTML with <b> tags, got: {content_highlight!r}"
)
def test_highlights_notes_returns_html_string(self, backend: TantivyBackend):
def test_highlights_notes_returns_html_string(
self,
backend: TantivyBackend,
) -> None:
"""Note highlights must be HTML strings via notes_text companion field.
The notes JSON field does not support tantivy SnippetGenerator; the
@@ -642,12 +660,12 @@ class TestHighlightHits:
f"Expected HTML with <b> tags, got: {note_highlight!r}"
)
def test_empty_doc_list_returns_empty_hits(self, backend: TantivyBackend):
def test_empty_doc_list_returns_empty_hits(self, backend: TantivyBackend) -> None:
"""highlight_hits with no doc IDs must return an empty list."""
hits = backend.highlight_hits("anything", [])
assert hits == []
def test_no_highlights_when_no_match(self, backend: TantivyBackend):
def test_no_highlights_when_no_match(self, backend: TantivyBackend) -> None:
"""Documents not matching the query should not appear in results."""
doc = Document.objects.create(
title="Unrelated",
@@ -79,60 +79,60 @@ class TestMigrateFulltextQueryFieldPrefixes(TestMigrations):
value="note:something",
)
def test_note_prefix_rewritten(self):
def test_note_prefix_rewritten(self) -> None:
self.rule_note.refresh_from_db()
self.assertEqual(self.rule_note.value, "notes.note:invoice")
def test_custom_field_prefix_rewritten(self):
def test_custom_field_prefix_rewritten(self) -> None:
self.rule_cf.refresh_from_db()
self.assertEqual(self.rule_cf.value, "custom_fields.value:amount")
def test_combined_query_rewritten(self):
def test_combined_query_rewritten(self) -> None:
self.rule_combined.refresh_from_db()
self.assertEqual(
self.rule_combined.value,
"notes.note:invoice AND custom_fields.value:total",
)
def test_parenthesized_groups(self):
def test_parenthesized_groups(self) -> None:
self.rule_parens.refresh_from_db()
self.assertEqual(
self.rule_parens.value,
"(notes.note:invoice OR notes.note:receipt)",
)
def test_plus_prefix(self):
def test_plus_prefix(self) -> None:
self.rule_plus.refresh_from_db()
self.assertEqual(self.rule_plus.value, "+notes.note:foo")
def test_minus_prefix(self):
def test_minus_prefix(self) -> None:
self.rule_minus.refresh_from_db()
self.assertEqual(self.rule_minus.value, "-notes.note:bar")
def test_boosted(self):
def test_boosted(self) -> None:
self.rule_boost.refresh_from_db()
self.assertEqual(self.rule_boost.value, "notes.note:test^2")
def test_no_match_unchanged(self):
def test_no_match_unchanged(self) -> None:
self.rule_no_match.refresh_from_db()
self.assertEqual(self.rule_no_match.value, "title:hello content:world")
def test_word_boundary_no_false_positive(self):
def test_word_boundary_no_false_positive(self) -> None:
self.rule_denote.refresh_from_db()
self.assertEqual(self.rule_denote.value, "denote:foo")
def test_already_migrated_idempotent(self):
def test_already_migrated_idempotent(self) -> None:
self.rule_already_migrated.refresh_from_db()
self.assertEqual(self.rule_already_migrated.value, "notes.note:foo")
def test_already_migrated_cf_idempotent(self):
def test_already_migrated_cf_idempotent(self) -> None:
self.rule_already_migrated_cf.refresh_from_db()
self.assertEqual(self.rule_already_migrated_cf.value, "custom_fields.value:bar")
def test_null_value_no_crash(self):
def test_null_value_no_crash(self) -> None:
self.rule_null.refresh_from_db()
self.assertIsNone(self.rule_null.value)
def test_non_fulltext_rule_untouched(self):
def test_non_fulltext_rule_untouched(self) -> None:
self.rule_other_type.refresh_from_db()
self.assertEqual(self.rule_other_type.value, "note:something")
+1 -1
View File
@@ -100,7 +100,7 @@ class TestTagAdmin(DirectoriesMixin, TestCase):
self.tag_admin = TagAdmin(model=Tag, admin_site=AdminSite())
@patch("documents.tasks.bulk_update_documents")
def test_parent_tags_get_added(self, mock_bulk_update):
def test_parent_tags_get_added(self, mock_bulk_update) -> None:
document = Document.objects.create(title="test")
parent = Tag.objects.create(name="parent")
child = Tag.objects.create(name="child")
@@ -91,6 +91,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
- app_title and app_logo are included
"""
config = ApplicationConfiguration.objects.first()
assert config is not None
config.app_title = "Fancy New Title"
config.app_logo = "/logo/example.jpg"
config.save()
@@ -125,6 +126,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
assert config is not None
self.assertEqual(config.color_conversion_strategy, ColorConvertChoices.RGB)
def test_api_update_config_empty_fields(self) -> None:
@@ -150,6 +152,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
assert config is not None
self.assertEqual(config.user_args, None)
self.assertEqual(config.language, None)
self.assertEqual(config.barcode_tag_mapping, None)
@@ -187,6 +190,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
self.assertIn("image/jpeg", response["Content-Type"])
config = ApplicationConfiguration.objects.first()
assert config is not None
old_logo = config.app_logo
self.assertTrue(Path(old_logo.path).exists())
self.client.patch(
@@ -233,6 +237,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
assert config is not None
with Image.open(config.app_logo.path) as stored_logo:
stored_exif = stored_logo.getexif()
@@ -268,6 +273,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
assert config is not None
with Image.open(config.app_logo.path) as stored_logo:
stored_text = stored_logo.text
@@ -786,6 +792,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
- llm_api_key is set to None
"""
config = ApplicationConfiguration.objects.first()
assert config is not None
config.llm_api_key = "1234567890"
config.save()
@@ -826,6 +833,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
- LLM index is triggered to update
"""
config = ApplicationConfiguration.objects.first()
assert config is not None
config.ai_enabled = False
config.llm_embedding_backend = None
config.save()
+1 -1
View File
@@ -918,7 +918,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
],
)
def test_api_selection_data_requires_view_permission(self):
def test_api_selection_data_requires_view_permission(self) -> None:
self.doc2.owner = self.user
self.doc2.save()
+48 -16
View File
@@ -276,7 +276,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
doc.refresh_from_db()
self.assertEqual(doc.custom_fields.first().value, None)
_cf_1 = doc.custom_fields.first()
assert _cf_1 is not None
self.assertEqual(_cf_1.value, None)
@mock.patch("documents.signals.handlers.process_cf_select_update.apply_async")
def test_custom_field_update_offloaded_once(self, mock_delay) -> None:
@@ -567,7 +569,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(CustomFieldInstance.objects.count(), 1)
self.assertEqual(doc.custom_fields.first().value, "test value")
_cf_2 = doc.custom_fields.first()
assert _cf_2 is not None
self.assertEqual(_cf_2.value, "test value")
# Update
resp = self.client.patch(
@@ -584,7 +588,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(CustomFieldInstance.objects.count(), 1)
self.assertEqual(doc.custom_fields.first().value, "a new test value")
_cf_3 = doc.custom_fields.first()
assert _cf_3 is not None
self.assertEqual(_cf_3.value, "a new test value")
def test_delete_custom_field_instance(self) -> None:
"""
@@ -650,7 +656,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(CustomFieldInstance.objects.count(), 1)
self.assertEqual(Document.objects.count(), 1)
self.assertEqual(len(doc.custom_fields.all()), 1)
self.assertEqual(doc.custom_fields.first().value, date_value)
_cf_4 = doc.custom_fields.first()
assert _cf_4 is not None
self.assertEqual(_cf_4.value, date_value)
def test_custom_field_validation(self) -> None:
"""
@@ -1062,9 +1070,15 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(CustomFieldInstance.objects.count(), 4)
self.assertEqual(doc2.custom_fields.first().value, [1])
self.assertEqual(doc3.custom_fields.first().value, [1])
self.assertEqual(doc4.custom_fields.first().value, [1])
_cf_5 = doc2.custom_fields.first()
assert _cf_5 is not None
self.assertEqual(_cf_5.value, [1])
_cf_6 = doc3.custom_fields.first()
assert _cf_6 is not None
self.assertEqual(_cf_6.value, [1])
_cf_7 = doc4.custom_fields.first()
assert _cf_7 is not None
self.assertEqual(_cf_7.value, [1])
# Add links appends if necessary
resp = self.client.patch(
@@ -1081,7 +1095,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(doc4.custom_fields.first().value, [1, 3])
_cf_8 = doc4.custom_fields.first()
assert _cf_8 is not None
self.assertEqual(_cf_8.value, [1, 3])
# Remove one of the links, removed on other doc
resp = self.client.patch(
@@ -1098,9 +1114,15 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(doc2.custom_fields.first().value, [1])
self.assertEqual(doc3.custom_fields.first().value, [1, 4])
self.assertEqual(doc4.custom_fields.first().value, [3])
_cf_9 = doc2.custom_fields.first()
assert _cf_9 is not None
self.assertEqual(_cf_9.value, [1])
_cf_10 = doc3.custom_fields.first()
assert _cf_10 is not None
self.assertEqual(_cf_10.value, [1, 4])
_cf_11 = doc4.custom_fields.first()
assert _cf_11 is not None
self.assertEqual(_cf_11.value, [3])
# Removes the field entirely
resp = self.client.patch(
@@ -1112,9 +1134,15 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(doc2.custom_fields.first().value, [])
self.assertEqual(doc3.custom_fields.first().value, [4])
self.assertEqual(doc4.custom_fields.first().value, [3])
_cf_12 = doc2.custom_fields.first()
assert _cf_12 is not None
self.assertEqual(_cf_12.value, [])
_cf_13 = doc3.custom_fields.first()
assert _cf_13 is not None
self.assertEqual(_cf_13.value, [4])
_cf_14 = doc4.custom_fields.first()
assert _cf_14 is not None
self.assertEqual(_cf_14.value, [3])
# If field exists on target doc but value is None
doc5 = Document.objects.create(
@@ -1139,7 +1167,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(doc5.custom_fields.first().value, [1])
_cf_15 = doc5.custom_fields.first()
assert _cf_15 is not None
self.assertEqual(_cf_15.value, [1])
def test_documentlink_patch_requires_change_permission_on_target_documents(
self,
@@ -1321,7 +1351,9 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
results = response.data["results"]
self.assertEqual(results[0]["document_count"], 0)
def test_patch_document_invalid_date_custom_field_returns_validation_error(self):
def test_patch_document_invalid_date_custom_field_returns_validation_error(
self,
) -> None:
"""
GIVEN:
- A date custom field
+6 -6
View File
@@ -1168,7 +1168,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
self.assertIn("all", response.data)
self.assertCountEqual(response.data["all"], [d.id for d in docs])
def test_default_ordering_uses_id_as_tiebreaker(self):
def test_default_ordering_uses_id_as_tiebreaker(self) -> None:
"""
GIVEN:
- Documents sharing the same created date
@@ -2156,7 +2156,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
match_tags,
match_document_types,
match_storage_paths,
):
) -> None:
doc = Document.objects.create(
title="test",
mime_type="application/pdf",
@@ -2193,7 +2193,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
match_document_types,
match_storage_paths,
mocked_load,
):
) -> None:
"""
GIVEN:
- Request for suggestions for a document
@@ -2276,7 +2276,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
def test_get_suggestions_dates_disabled(
self,
mock_get_date_parser: mock.MagicMock,
):
) -> None:
"""
GIVEN:
- NUMBER_OF_SUGGESTED_DATES = 0 (disables feature)
@@ -3409,7 +3409,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
def test_create_share_link_requires_view_permission_for_document(self):
def test_create_share_link_requires_view_permission_for_document(self) -> None:
"""
GIVEN:
- A user with add_sharelink but without view permission on a document
@@ -3457,7 +3457,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
self.assertEqual(create_resp.status_code, status.HTTP_201_CREATED)
self.assertEqual(create_resp.data["document"], doc.pk)
def test_next_asn(self):
def test_next_asn(self) -> None:
"""
GIVEN:
- Existing documents with ASNs, highest owned by user2
+1 -1
View File
@@ -933,7 +933,7 @@ class TestApiUser(DirectoriesMixin, APITestCase):
returned_user1 = User.objects.get(pk=user1.pk)
self.assertEqual(returned_user1.is_superuser, False)
def test_only_superusers_can_create_or_alter_staff_status(self):
def test_only_superusers_can_create_or_alter_staff_status(self) -> None:
"""
GIVEN:
- Existing user account
+24 -14
View File
@@ -79,14 +79,14 @@ class TestApiSchema(APITestCase):
class TestTasksSummarySchema:
"""tasks_summary_retrieve: response must be an array of TaskSummarySerializer."""
def test_summary_response_is_array(self, api_schema: SchemaGenerator):
def test_summary_response_is_array(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/tasks/summary/"]["get"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
assert resp_200["type"] == "array", (
"tasks_summary_retrieve response must be type:array"
)
def test_summary_items_have_total_count(self, api_schema: SchemaGenerator):
def test_summary_items_have_total_count(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/tasks/summary/"]["get"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
items = resp_200.get("items", {})
@@ -100,7 +100,10 @@ class TestTasksSummarySchema:
"summary items must have 'total_count' (TaskSummarySerializer)"
)
def test_summary_days_parameter_constraints(self, api_schema: SchemaGenerator):
def test_summary_days_parameter_constraints(
self,
api_schema: SchemaGenerator,
) -> None:
op = api_schema["paths"]["/api/tasks/summary/"]["get"]
params = {p["name"]: p for p in op.get("parameters", [])}
assert "days" in params, "days query parameter must be declared"
@@ -112,14 +115,14 @@ class TestTasksSummarySchema:
class TestTasksActiveSchema:
"""tasks_active_retrieve: response must be an array of TaskSerializerV10."""
def test_active_response_is_array(self, api_schema: SchemaGenerator):
def test_active_response_is_array(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/tasks/active/"]["get"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
assert resp_200["type"] == "array", (
"tasks_active_retrieve response must be type:array"
)
def test_active_items_ref_named_schema(self, api_schema: SchemaGenerator):
def test_active_items_ref_named_schema(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/tasks/active/"]["get"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
items = resp_200.get("items", {})
@@ -133,7 +136,11 @@ class TestMetadataSchema:
"""Metadata component: array fields and optional archive fields."""
@pytest.mark.parametrize("field", ["original_metadata", "archive_metadata"])
def test_metadata_field_is_array(self, api_schema: SchemaGenerator, field: str):
def test_metadata_field_is_array(
self,
api_schema: SchemaGenerator,
field: str,
) -> None:
props = api_schema["components"]["schemas"]["Metadata"]["properties"]
assert props[field]["type"] == "array", (
f"{field} should be type:array, not type:object"
@@ -144,7 +151,7 @@ class TestMetadataSchema:
self,
api_schema: SchemaGenerator,
field: str,
):
) -> None:
props = api_schema["components"]["schemas"]["Metadata"]["properties"]
items = props[field]["items"]
ref = items.get("$ref", "")
@@ -166,7 +173,7 @@ class TestMetadataSchema:
"archive_metadata",
],
)
def test_archive_field_not_required(self, api_schema, field):
def test_archive_field_not_required(self, api_schema, field) -> None:
schema = api_schema["components"]["schemas"]["Metadata"]
required = schema.get("required", [])
assert field not in required
@@ -179,7 +186,7 @@ class TestMetadataSchema:
class TestStoragePathTestSchema:
"""storage_paths_test_create: response must be a string, not a StoragePath object."""
def test_test_action_response_is_string(self, api_schema: SchemaGenerator):
def test_test_action_response_is_string(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/storage_paths/test/"]["post"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
assert resp_200.get("type") == "string", (
@@ -189,7 +196,7 @@ class TestStoragePathTestSchema:
def test_test_action_request_uses_storage_path_test_serializer(
self,
api_schema: SchemaGenerator,
):
) -> None:
op = api_schema["paths"]["/api/storage_paths/test/"]["post"]
content = (
op.get("requestBody", {}).get("content", {}).get("application/json", {})
@@ -220,11 +227,14 @@ class TestProcessedMailBulkDeleteSchema:
self,
api_schema: SchemaGenerator,
field: str,
):
) -> None:
props = self._get_props(api_schema)
assert field in props, f"bulk_delete 200 response must have a '{field}' field"
def test_bulk_delete_response_is_not_processed_mail_serializer(self, api_schema):
def test_bulk_delete_response_is_not_processed_mail_serializer(
self,
api_schema,
) -> None:
op = api_schema["paths"]["/api/processed_mail/bulk_delete/"]["post"]
resp_200 = op["responses"]["200"]["content"]["application/json"]["schema"]
ref = resp_200.get("$ref", "")
@@ -237,13 +247,13 @@ class TestProcessedMailBulkDeleteSchema:
class TestShareLinkBundleRebuildSchema:
"""share_link_bundles_rebuild_create: 200 returns bundle data; 400 is documented."""
def test_rebuild_has_400_response(self, api_schema: SchemaGenerator):
def test_rebuild_has_400_response(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/share_link_bundles/{id}/rebuild/"]["post"]
assert "400" in op["responses"], (
"rebuild must document the 400 response for 'Bundle is already being processed.'"
)
def test_rebuild_400_has_detail_field(self, api_schema: SchemaGenerator):
def test_rebuild_400_has_detail_field(self, api_schema: SchemaGenerator) -> None:
op = api_schema["paths"]["/api/share_link_bundles/{id}/rebuild/"]["post"]
resp_400 = op["responses"]["400"]["content"]["application/json"]["schema"]
ref = resp_400.get("$ref", "")
+1 -1
View File
@@ -943,7 +943,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=things")
self.assertIsNone(response.data["corrected_query"])
def test_search_spelling_suggestion_suppressed_for_private_terms(self):
def test_search_spelling_suggestion_suppressed_for_private_terms(self) -> None:
owner = User.objects.create_user("owner")
attacker = User.objects.create_user("attacker")
attacker.user_permissions.add(
+24 -39
View File
@@ -273,6 +273,7 @@ class TestApiWorkflows(DirectoriesMixin, APITestCase):
self.assertEqual(Workflow.objects.count(), 2)
workflow = Workflow.objects.get(name="Workflow 2")
trigger = workflow.triggers.first()
assert trigger is not None
self.assertSetEqual(
set(trigger.filter_has_tags.values_list("id", flat=True)),
{self.t1.id},
@@ -493,44 +494,24 @@ class TestApiWorkflows(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
workflow = Workflow.objects.get(id=response.data["id"])
self.assertEqual(workflow.name, "Workflow Updated")
self.assertEqual(workflow.triggers.first().filter_has_tags.first(), self.t1)
trigger = workflow.triggers.first()
assert trigger is not None
action = workflow.actions.first()
assert action is not None
self.assertEqual(trigger.filter_has_tags.first(), self.t1)
self.assertEqual(trigger.filter_has_all_tags.first(), self.t2)
self.assertEqual(trigger.filter_has_not_tags.first(), self.t3)
self.assertEqual(trigger.filter_has_any_correspondents.first(), self.c)
self.assertEqual(trigger.filter_has_not_correspondents.first(), self.c2)
self.assertEqual(trigger.filter_has_any_document_types.first(), self.dt)
self.assertEqual(trigger.filter_has_not_document_types.first(), self.dt2)
self.assertEqual(trigger.filter_has_any_storage_paths.first(), self.sp)
self.assertEqual(trigger.filter_has_not_storage_paths.first(), self.sp2)
self.assertEqual(
workflow.triggers.first().filter_has_all_tags.first(),
self.t2,
)
self.assertEqual(
workflow.triggers.first().filter_has_not_tags.first(),
self.t3,
)
self.assertEqual(
workflow.triggers.first().filter_has_any_correspondents.first(),
self.c,
)
self.assertEqual(
workflow.triggers.first().filter_has_not_correspondents.first(),
self.c2,
)
self.assertEqual(
workflow.triggers.first().filter_has_any_document_types.first(),
self.dt,
)
self.assertEqual(
workflow.triggers.first().filter_has_not_document_types.first(),
self.dt2,
)
self.assertEqual(
workflow.triggers.first().filter_has_any_storage_paths.first(),
self.sp,
)
self.assertEqual(
workflow.triggers.first().filter_has_not_storage_paths.first(),
self.sp2,
)
self.assertEqual(
workflow.triggers.first().filter_custom_field_query,
trigger.filter_custom_field_query,
json.dumps(["AND", [[self.cf1.id, "exact", "value"]]]),
)
self.assertEqual(workflow.actions.first().assign_title, "Action New Title")
self.assertEqual(action.assign_title, "Action New Title")
def test_api_update_workflow_no_trigger_actions(self) -> None:
"""
@@ -612,9 +593,13 @@ class TestApiWorkflows(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
workflow = Workflow.objects.get(id=response.data["id"])
self.assertEqual(WorkflowTrigger.objects.all().count(), 1)
self.assertNotEqual(workflow.triggers.first().id, self.trigger.id)
new_trigger = workflow.triggers.first()
assert new_trigger is not None
self.assertNotEqual(new_trigger.id, self.trigger.id)
self.assertEqual(WorkflowAction.objects.all().count(), 1)
self.assertNotEqual(workflow.actions.first().id, self.action.id)
new_action = workflow.actions.first()
assert new_action is not None
self.assertNotEqual(new_action.id, self.action.id)
def test_email_action_validation(self) -> None:
"""
@@ -873,7 +858,7 @@ class TestApiWorkflows(DirectoriesMixin, APITestCase):
self.action.refresh_from_db()
self.assertEqual(self.action.assign_title, "Patched Title")
def test_password_action_passwords_field(self):
def test_password_action_passwords_field(self) -> None:
"""
GIVEN:
- Nothing
@@ -896,7 +881,7 @@ class TestApiWorkflows(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data["passwords"], passwords)
def test_password_action_invalid_passwords_field(self):
def test_password_action_invalid_passwords_field(self) -> None:
"""
GIVEN:
- Nothing
+9 -3
View File
@@ -86,7 +86,7 @@ class TestBarcode(
self.assertDictEqual(separator_page_numbers, {1: False})
@override_settings(CONSUMER_ENABLE_ASN_BARCODE=True)
def test_asn_barcode_duplicate_in_trash_fails(self):
def test_asn_barcode_duplicate_in_trash_fails(self) -> None:
"""
GIVEN:
- A document with ASN barcode 123 is in the trash
@@ -585,6 +585,7 @@ class TestBarcode(
- The barcode config is used
"""
app_config = ApplicationConfiguration.objects.first()
assert app_config is not None
app_config.barcodes_enabled = True
app_config.barcode_string = "CUSTOM BARCODE"
app_config.save()
@@ -771,6 +772,7 @@ class TestAsnBarcode(DirectoriesMixin, SampleDirMixin, GetReaderPluginMixin, Tes
)
document = Document.objects.first()
assert document is not None
self.assertEqual(document.archive_serial_number, 123)
@@ -1059,11 +1061,15 @@ class TestTagBarcode(DirectoriesMixin, SampleDirMixin, GetReaderPluginMixin, Tes
doc2 = documents[1]
self.assertEqual(doc2.tags.count(), 1)
self.assertEqual(doc2.tags.first().name, "invoice")
_tag_1 = doc2.tags.first()
assert _tag_1 is not None
self.assertEqual(_tag_1.name, "invoice")
doc3 = documents[2]
self.assertEqual(doc3.tags.count(), 1)
self.assertEqual(doc3.tags.first().name, "receipt")
_tag_2 = doc3.tags.first()
assert _tag_2 is not None
self.assertEqual(_tag_2.name, "receipt")
@override_settings(
CONSUMER_ENABLE_TAG_BARCODE=True,
+28 -30
View File
@@ -319,8 +319,10 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
[self.doc3.id],
)
# assert reflect document link
_cf_1 = self.doc3.custom_fields.first()
assert _cf_1 is not None
self.assertEqual(
self.doc3.custom_fields.first().value,
_cf_1.value,
[self.doc2.id, self.doc1.id],
)
@@ -334,14 +336,12 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
add_custom_fields={},
remove_custom_fields=[cf3.id],
)
self.assertNotIn(
self.doc3.id,
self.doc1.custom_fields.filter(field=cf3).first().value,
)
self.assertNotIn(
self.doc3.id,
self.doc2.custom_fields.filter(field=cf3).first().value,
)
_cf_2 = self.doc1.custom_fields.filter(field=cf3).first()
assert _cf_2 is not None
self.assertNotIn(self.doc3.id, _cf_2.value)
_cf_3 = self.doc2.custom_fields.filter(field=cf3).first()
assert _cf_3 is not None
self.assertNotIn(self.doc3.id, _cf_3.value)
def test_modify_custom_fields_doclink_self_link(self) -> None:
"""
@@ -363,14 +363,12 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
remove_custom_fields=[],
)
self.assertEqual(
self.doc1.custom_fields.first().value,
[self.doc2.id],
)
self.assertEqual(
self.doc2.custom_fields.first().value,
[self.doc1.id],
)
_cf_4 = self.doc1.custom_fields.first()
assert _cf_4 is not None
self.assertEqual(_cf_4.value, [self.doc2.id])
_cf_5 = self.doc2.custom_fields.first()
assert _cf_5 is not None
self.assertEqual(_cf_5.value, [self.doc1.id])
def test_delete(self) -> None:
self.assertEqual(Document.objects.count(), 5)
@@ -693,7 +691,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self,
mock_consume_file,
mock_delete_documents,
):
) -> None:
"""
GIVEN:
- Existing documents
@@ -932,7 +930,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_chord,
mock_consume_file,
mock_delete_documents,
):
) -> None:
"""
GIVEN:
- Existing documents
@@ -1027,7 +1025,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_not_called()
@mock.patch("documents.tasks.consume_file.apply_async")
def test_rotate(self, mock_consume_delay):
def test_rotate(self, mock_consume_delay) -> None:
"""
GIVEN:
- Existing documents
@@ -1054,7 +1052,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self,
mock_pdf_save,
mock_consume_delay,
):
) -> None:
"""
GIVEN:
- Existing documents
@@ -1078,7 +1076,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
def test_rotate_non_pdf(
self,
mock_consume_delay,
):
) -> None:
"""
GIVEN:
- Existing documents
@@ -1105,7 +1103,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_open,
mock_consume_delay,
mock_magic,
):
) -> None:
Document.objects.create(
checksum="B-v1",
title="B version 1",
@@ -1128,7 +1126,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
def test_delete_pages(self, mock_magic, mock_pdf_save, mock_consume_delay):
def test_delete_pages(self, mock_magic, mock_pdf_save, mock_consume_delay) -> None:
"""
GIVEN:
- Existing documents
@@ -1159,7 +1157,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_open,
mock_consume_delay,
mock_magic,
):
) -> None:
Document.objects.create(
checksum="B-v1",
title="B version 1",
@@ -1181,7 +1179,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
def test_delete_pages_with_error(self, mock_pdf_save, mock_consume_delay):
def test_delete_pages_with_error(self, mock_pdf_save, mock_consume_delay) -> None:
"""
GIVEN:
- Existing documents
@@ -1300,7 +1298,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(self.doc2.archive_serial_number, 333)
@mock.patch("documents.tasks.consume_file.apply_async")
def test_edit_pdf_with_update_document(self, mock_consume_delay):
def test_edit_pdf_with_update_document(self, mock_consume_delay) -> None:
"""
GIVEN:
- A single existing PDF document
@@ -1338,7 +1336,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_new,
mock_consume_delay,
mock_magic,
):
) -> None:
Document.objects.create(
checksum="B-v1",
title="B version 1",
@@ -1416,7 +1414,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self,
mock_consume_file,
mock_group,
):
) -> None:
"""
GIVEN:
- Existing document
@@ -1446,7 +1444,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_mkdtemp,
mock_consume_delay,
mock_update_document,
):
) -> None:
doc = self.doc1
temp_dir = self.dirs.scratch_dir / "remove-password-update"
temp_dir.mkdir(parents=True, exist_ok=True)
+26
View File
@@ -246,6 +246,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
@@ -297,6 +298,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
@@ -316,6 +318,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
@@ -331,6 +334,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
@@ -347,6 +351,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
@@ -363,6 +368,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.document_type.id, dt.id)
self._assert_first_last_send_progress()
@@ -377,6 +383,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.storage_path.id, sp.id)
self._assert_first_last_send_progress()
@@ -393,6 +400,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIn(t1, document.tags.all())
self.assertNotIn(t2, document.tags.all())
@@ -419,6 +427,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
fields_used = [
field_instance.field for field_instance in document.custom_fields.all()
@@ -441,6 +450,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.archive_serial_number, 123)
self._assert_first_last_send_progress()
@@ -460,6 +470,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
now = timezone.now()
self.assertEqual(document.title, f"{c.name}{dt.name} {now.strftime('%m-%y')}")
@@ -475,6 +486,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.owner, testuser)
self._assert_first_last_send_progress()
@@ -493,6 +505,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
user_checker = ObjectPermissionChecker(testuser)
self.assertTrue(user_checker.has_perm("view_document", document))
@@ -565,6 +578,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
document.delete()
with self.assertRaisesMessage(ConsumerError, "document is in the trash"):
@@ -645,6 +659,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.title, "new docs")
self.assertEqual(document.filename, "none/new docs.pdf")
@@ -666,6 +681,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertIsNotNone(document)
assert document is not None
@@ -704,6 +720,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.title, "new docs")
self.assertIsNotNone(document.title)
@@ -724,6 +741,7 @@ class TestConsumer(
document = Document.objects.first()
assert document is not None
assert document is not None
self.assertEqual(document.version_label, "v1")
@@ -940,6 +958,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(document.correspondent, correspondent)
self.assertEqual(document.document_type, dtype)
@@ -957,6 +976,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self._assert_first_last_send_progress()
@@ -987,6 +1007,7 @@ class TestConsumer(
# Move the existing document to trash
document = Document.objects.first()
assert document is not None
document.delete()
dst = self.get_test_file()
@@ -1015,6 +1036,7 @@ class TestConsumer(
consumer.run()
document = Document.objects.first()
assert document is not None
self._assert_first_last_send_progress()
@@ -1173,6 +1195,7 @@ class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(
document.created,
@@ -1203,6 +1226,7 @@ class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(
document.created,
@@ -1233,6 +1257,7 @@ class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(
document.created,
@@ -1265,6 +1290,7 @@ class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
consumer.run()
document = Document.objects.first()
assert document is not None
self.assertEqual(
document.created,
+2 -2
View File
@@ -1514,7 +1514,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
Path("somepath/asn-201-400/asn-3xx/Does Matter.pdf"),
)
def test_template_related_context_keeps_legacy_string_coercion(self):
def test_template_related_context_keeps_legacy_string_coercion(self) -> None:
"""
GIVEN:
- A storage path template that uses related objects directly as strings
@@ -1908,7 +1908,7 @@ class TestCustomFieldFilenameUpdates(
self.assertLessEqual(m.call_count, 1)
@override_settings(FILENAME_FORMAT=None)
def test_overlong_storage_path_keeps_existing_filename(self):
def test_overlong_storage_path_keeps_existing_filename(self) -> None:
initial_filename = generate_filename(self.doc)
Document.objects.filter(pk=self.doc.pk).update(filename=str(initial_filename))
self.doc.refresh_from_db()
+5 -5
View File
@@ -83,7 +83,7 @@ class TestDateLocalization:
format_style: str,
locale_str: str,
expected_output: str,
):
) -> None:
"""
Tests `localize_date` with `date` objects across different locales and formats.
"""
@@ -144,7 +144,7 @@ class TestDateLocalization:
format_style: str,
locale_str: str,
expected_output: str,
):
) -> None:
# To handle the non-breaking space in French and other locales
result = localize_date(value, format_style, locale_str)
assert result.replace("\u202f", " ") == expected_output.replace("\u202f", " ")
@@ -161,7 +161,7 @@ class TestDateLocalization:
def test_localize_date_raises_type_error_for_invalid_input(
self,
invalid_value: None | list[object] | dict[Any, Any] | Literal[1698330605],
):
) -> None:
with pytest.raises(TypeError) as excinfo:
localize_date(invalid_value, "medium", "en_US")
@@ -228,7 +228,7 @@ class TestDateLocalization:
format_style: str,
locale_str: str,
expected_output: str,
):
) -> None:
"""
Tests `localize_date` with `date` string across different locales and formats.
"""
@@ -289,7 +289,7 @@ class TestDateLocalization:
format_style: str,
locale_str: str,
expected_output: str,
):
) -> None:
"""
Tests `localize_date` with `date` string across different locales and formats.
"""
+18 -13
View File
@@ -10,10 +10,10 @@ from documents.regex import validate_regex_pattern
class TestValidateRegexPattern:
def test_valid_pattern(self):
def test_valid_pattern(self) -> None:
validate_regex_pattern(r"\d+")
def test_invalid_pattern_raises(self):
def test_invalid_pattern_raises(self) -> None:
with pytest.raises(ValueError):
validate_regex_pattern(r"[invalid")
@@ -40,7 +40,7 @@ class TestSafeRegexSearchAndMatch:
),
],
)
def test_match_found(self, func, pattern, text, expected_group):
def test_match_found(self, func, pattern, text, expected_group) -> None:
result = func(pattern, text)
assert result is not None
assert result.group() == expected_group
@@ -52,7 +52,7 @@ class TestSafeRegexSearchAndMatch:
pytest.param(safe_regex_match, r"\d+", "abc123", id="match-no-match"),
],
)
def test_no_match(self, func, pattern, text):
def test_no_match(self, func, pattern, text) -> None:
assert func(pattern, text) is None
@pytest.mark.parametrize(
@@ -62,7 +62,7 @@ class TestSafeRegexSearchAndMatch:
pytest.param(safe_regex_match, id="match"),
],
)
def test_invalid_pattern_returns_none(self, func):
def test_invalid_pattern_returns_none(self, func) -> None:
assert func(r"[invalid", "test") is None
@pytest.mark.parametrize(
@@ -72,7 +72,7 @@ class TestSafeRegexSearchAndMatch:
pytest.param(safe_regex_match, id="match"),
],
)
def test_flags_respected(self, func):
def test_flags_respected(self, func) -> None:
assert func(r"abc", "ABC", flags=regex.IGNORECASE) is not None
@pytest.mark.parametrize(
@@ -82,7 +82,12 @@ class TestSafeRegexSearchAndMatch:
pytest.param(safe_regex_match, "match", id="match"),
],
)
def test_timeout_returns_none(self, func, method_name, mocker: MockerFixture):
def test_timeout_returns_none(
self,
func,
method_name,
mocker: MockerFixture,
) -> None:
mock_compile = mocker.patch("documents.regex.regex.compile")
getattr(mock_compile.return_value, method_name).side_effect = TimeoutError
assert func(r"\d+", "test") is None
@@ -97,31 +102,31 @@ class TestSafeRegexSub:
pytest.param(r"abc", "X", "ABC", "X", id="flags"),
],
)
def test_substitution(self, pattern, repl, text, expected):
def test_substitution(self, pattern, repl, text, expected) -> None:
flags = regex.IGNORECASE if pattern == r"abc" else 0
result = safe_regex_sub(pattern, repl, text, flags=flags)
assert result == expected
def test_invalid_pattern_returns_none(self):
def test_invalid_pattern_returns_none(self) -> None:
assert safe_regex_sub(r"[invalid", "x", "test") is None
def test_timeout_returns_none(self, mocker: MockerFixture):
def test_timeout_returns_none(self, mocker: MockerFixture) -> None:
mock_compile = mocker.patch("documents.regex.regex.compile")
mock_compile.return_value.sub.side_effect = TimeoutError
assert safe_regex_sub(r"\d+", "X", "test") is None
class TestSafeRegexFinditer:
def test_yields_matches(self):
def test_yields_matches(self) -> None:
pattern = regex.compile(r"\d+")
matches = list(safe_regex_finditer(pattern, "a1b22c333"))
assert [m.group() for m in matches] == ["1", "22", "333"]
def test_no_matches(self):
def test_no_matches(self) -> None:
pattern = regex.compile(r"\d+")
assert list(safe_regex_finditer(pattern, "abcdef")) == []
def test_timeout_stops_iteration(self, mocker: MockerFixture):
def test_timeout_stops_iteration(self, mocker: MockerFixture) -> None:
mock_pattern = mocker.MagicMock()
mock_pattern.finditer.side_effect = TimeoutError
mock_pattern.pattern = r"\d+"
@@ -280,6 +280,7 @@ class ShareLinkBundleBuildTaskTests(DirectoriesMixin, APITestCase):
self.document.archive_filename = f"{self.document.pk:07}.pdf"
self.document.save()
path = self.document.archive_path
assert path is not None
else:
path = self.document.source_path
path.parent.mkdir(parents=True, exist_ok=True)
@@ -304,6 +305,7 @@ class ShareLinkBundleBuildTaskTests(DirectoriesMixin, APITestCase):
self.assertGreater(bundle.size_bytes or 0, 0)
final_path = bundle.absolute_file_path
self.assertIsNotNone(final_path)
assert final_path is not None
self.assertTrue(final_path.exists())
with zipfile.ZipFile(final_path) as zipf:
names = zipf.namelist()
@@ -327,6 +329,7 @@ class ShareLinkBundleBuildTaskTests(DirectoriesMixin, APITestCase):
bundle.refresh_from_db()
final_path = bundle.absolute_file_path
self.assertIsNotNone(final_path)
assert final_path is not None
self.assertTrue(final_path.exists())
self.assertNotEqual(final_path.read_bytes(), b"old")
@@ -354,6 +357,7 @@ class ShareLinkBundleBuildTaskTests(DirectoriesMixin, APITestCase):
bundle.refresh_from_db()
self.assertEqual(bundle.status, ShareLinkBundle.Status.FAILED)
self.assertIsInstance(bundle.last_error, dict)
assert isinstance(bundle.last_error, dict)
self.assertEqual(bundle.last_error.get("message"), "zip failure")
self.assertEqual(bundle.last_error.get("exception_type"), "RuntimeError")
scratch_zips = list(Path(settings.SCRATCH_DIR).glob("*.zip"))
+28 -24
View File
@@ -56,7 +56,11 @@ def send_publish(
@pytest.mark.django_db
class TestBeforeTaskPublishHandler:
def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides):
def test_creates_task_for_consume_file(
self,
consume_input_doc,
consume_overrides,
) -> None:
task_id = send_publish(
"documents.tasks.consume_file",
(),
@@ -70,18 +74,18 @@ class TestBeforeTaskPublishHandler:
assert task.input_data["filename"] == "invoice.pdf"
assert task.owner_id == consume_overrides.owner_id
def test_creates_task_for_train_classifier(self):
def test_creates_task_for_train_classifier(self) -> None:
task_id = send_publish("documents.tasks.train_classifier", (), {})
task = PaperlessTask.objects.get(task_id=task_id)
assert task.task_type == PaperlessTask.TaskType.TRAIN_CLASSIFIER
assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL
def test_creates_task_for_sanity_check(self):
def test_creates_task_for_sanity_check(self) -> None:
task_id = send_publish("documents.tasks.sanity_check", (), {})
task = PaperlessTask.objects.get(task_id=task_id)
assert task.task_type == PaperlessTask.TaskType.SANITY_CHECK
def test_creates_task_for_process_mail_accounts(self):
def test_creates_task_for_process_mail_accounts(self) -> None:
task_id = send_publish(
"paperless_mail.tasks.process_mail_accounts",
(),
@@ -91,13 +95,13 @@ class TestBeforeTaskPublishHandler:
assert task.task_type == PaperlessTask.TaskType.MAIL_FETCH
assert task.input_data["account_ids"] == [1, 2]
def test_mail_fetch_no_account_ids_stores_empty_input(self):
def test_mail_fetch_no_account_ids_stores_empty_input(self) -> None:
"""Beat-scheduled mail checks pass no account_ids; input_data should be {} not {"account_ids": None}."""
task_id = send_publish("paperless_mail.tasks.process_mail_accounts", (), {})
task = PaperlessTask.objects.get(task_id=task_id)
assert task.input_data == {}
def test_overrides_date_serialized_as_iso_string(self, consume_input_doc):
def test_overrides_date_serialized_as_iso_string(self, consume_input_doc) -> None:
"""A datetime.date in overrides is stored as an ISO string so input_data is JSON-safe."""
overrides = DocumentMetadataOverrides(created=datetime.date(2024, 1, 15))
@@ -110,7 +114,7 @@ class TestBeforeTaskPublishHandler:
task = PaperlessTask.objects.get(task_id=task_id)
assert task.input_data["overrides"]["created"] == "2024-01-15"
def test_overrides_path_serialized_as_string(self, consume_input_doc):
def test_overrides_path_serialized_as_string(self, consume_input_doc) -> None:
"""A Path value in overrides is stored as a plain string so input_data is JSON-safe."""
overrides = DocumentMetadataOverrides()
overrides.filename = Path("/uploads/invoice.pdf") # type: ignore[assignment]
@@ -159,11 +163,11 @@ class TestBeforeTaskPublishHandler:
task = PaperlessTask.objects.get(task_id=task_id)
assert task.trigger_source == expected_trigger_source
def test_ignores_untracked_task(self):
def test_ignores_untracked_task(self) -> None:
send_publish("documents.tasks.some_untracked_task", (), {})
assert PaperlessTask.objects.count() == 0
def test_ignores_none_headers(self):
def test_ignores_none_headers(self) -> None:
before_task_publish_handler(sender=None, headers=None, body=None)
assert PaperlessTask.objects.count() == 0
@@ -185,7 +189,7 @@ class TestBeforeTaskPublishHandler:
@pytest.mark.django_db
class TestTaskPrerunHandler:
def test_marks_task_started(self):
def test_marks_task_started(self) -> None:
task = PaperlessTaskFactory(status=PaperlessTask.Status.PENDING)
task_prerun_handler(task_id=task.task_id)
@@ -215,7 +219,7 @@ class TestTaskPostrunHandler:
date_started=timezone.now(),
)
def test_records_success_with_dict_result(self):
def test_records_success_with_dict_result(self) -> None:
task = self._started_task()
task_postrun_handler(
@@ -230,7 +234,7 @@ class TestTaskPostrunHandler:
assert task.duration_seconds is not None
assert task.wait_time_seconds is not None
def test_skips_failure_state(self):
def test_skips_failure_state(self) -> None:
"""postrun skips FAILURE; task_failure_handler owns that path."""
task = self._started_task()
@@ -238,7 +242,7 @@ class TestTaskPostrunHandler:
task.refresh_from_db()
assert task.status == PaperlessTask.Status.STARTED
def test_records_success_with_consume_result(self):
def test_records_success_with_consume_result(self) -> None:
"""ConsumeFileSuccessResult dict is stored directly as result_data."""
from documents.data_models import ConsumeFileSuccessResult
@@ -251,7 +255,7 @@ class TestTaskPostrunHandler:
task.refresh_from_db()
assert task.result_data == {"document_id": 42}
def test_records_stopped_with_reason(self):
def test_records_stopped_with_reason(self) -> None:
"""ConsumeFileStoppedResult dict is stored directly as result_data."""
from documents.data_models import ConsumeFileStoppedResult
@@ -264,14 +268,14 @@ class TestTaskPostrunHandler:
task.refresh_from_db()
assert task.result_data == {"reason": "Barcode splitting complete!"}
def test_none_retval_stores_no_result_data(self):
def test_none_retval_stores_no_result_data(self) -> None:
"""None return value (non-consume tasks) leaves result_data untouched."""
task = self._started_task()
task_postrun_handler(task_id=task.task_id, retval=None, state="SUCCESS")
task.refresh_from_db()
assert task.result_data is None
def test_ignores_unknown_task_id(self):
def test_ignores_unknown_task_id(self) -> None:
task_postrun_handler(
task_id="nonexistent",
@@ -279,7 +283,7 @@ class TestTaskPostrunHandler:
state="SUCCESS",
) # must not raise
def test_records_revoked_state(self):
def test_records_revoked_state(self) -> None:
task = self._started_task()
task_postrun_handler(task_id=task.task_id, retval=None, state="REVOKED")
@@ -289,7 +293,7 @@ class TestTaskPostrunHandler:
@pytest.mark.django_db
class TestTaskFailureHandler:
def test_records_failure_with_exception(self):
def test_records_failure_with_exception(self) -> None:
task = PaperlessTaskFactory(
task_type=PaperlessTask.TaskType.CONSUME_FILE,
@@ -308,7 +312,7 @@ class TestTaskFailureHandler:
assert task.result_data["error_message"] == "PDF parse failed"
assert task.date_done is not None
def test_records_traceback_when_provided(self):
def test_records_traceback_when_provided(self) -> None:
task = PaperlessTaskFactory(
task_type=PaperlessTask.TaskType.CONSUME_FILE,
@@ -331,7 +335,7 @@ class TestTaskFailureHandler:
assert "traceback" in task.result_data
assert len(task.result_data["traceback"]) <= 5000
def test_computes_duration_and_wait_time(self):
def test_computes_duration_and_wait_time(self) -> None:
now = timezone.now()
task = PaperlessTaskFactory(
@@ -350,14 +354,14 @@ class TestTaskFailureHandler:
assert task.duration_seconds == pytest.approx(5.0, abs=1.0)
assert task.wait_time_seconds == pytest.approx(5.0, abs=1.0)
def test_ignores_none_task_id(self):
def test_ignores_none_task_id(self) -> None:
task_failure_handler(task_id=None, exception=ValueError("x"), traceback=None)
@pytest.mark.django_db
class TestTaskRevokedHandler:
def test_marks_task_revoked(self, mocker: pytest_mock.MockerFixture):
def test_marks_task_revoked(self, mocker: pytest_mock.MockerFixture) -> None:
"""task_revoked_handler moves a queued task to REVOKED and stamps date_done."""
task = PaperlessTaskFactory(status=PaperlessTask.Status.PENDING)
request = mocker.MagicMock()
@@ -368,12 +372,12 @@ class TestTaskRevokedHandler:
assert task.status == PaperlessTask.Status.REVOKED
assert task.date_done is not None
def test_ignores_none_request(self):
def test_ignores_none_request(self) -> None:
"""task_revoked_handler must not raise when request is None."""
task_revoked_handler(request=None) # must not raise
def test_ignores_unknown_task_id(self, mocker: pytest_mock.MockerFixture):
def test_ignores_unknown_task_id(self, mocker: pytest_mock.MockerFixture) -> None:
"""task_revoked_handler must not raise for a task_id not in the database."""
request = mocker.MagicMock()
request.id = "nonexistent-id"
+54 -68
View File
@@ -187,6 +187,7 @@ class TestWorkflows(
)
document = Document.objects.first()
assert document is not None
self.assertEqual(document.correspondent, self.c)
self.assertEqual(document.document_type, self.dt)
self.assertEqual(list(document.tags.all()), [self.t1, self.t2, self.t3])
@@ -298,6 +299,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertEqual(document.correspondent, self.c)
self.assertEqual(document.document_type, self.dt)
self.assertEqual(list(document.tags.all()), [self.t1, self.t2, self.t3])
@@ -415,6 +417,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
# workflow 1
self.assertEqual(document.document_type, self.dt)
# workflow 2
@@ -483,6 +486,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertEqual(document.title, "Doc fnmatch title")
expected_str = f"Document matched {trigger} from {w}"
@@ -535,6 +539,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
self.assertEqual(document.tags.all().count(), 0)
@@ -547,7 +552,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(get_groups_with_perms(document).count(), 0)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
get_users_with_perms(
document,
@@ -555,7 +561,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(get_groups_with_perms(document).count(), 0)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(document.title, "simple")
expected_str = f"Document did not match {w}"
@@ -609,6 +616,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
self.assertEqual(document.tags.all().count(), 0)
@@ -621,12 +629,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
get_users_with_perms(
document,
@@ -634,12 +638,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(document.title, "simple")
expected_str = f"Document did not match {w}"
@@ -696,6 +696,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
self.assertEqual(document.tags.all().count(), 0)
@@ -708,12 +709,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
get_users_with_perms(
document,
@@ -721,12 +718,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(document.title, "simple")
expected_str = f"Document did not match {w}"
@@ -780,6 +773,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
self.assertEqual(document.tags.all().count(), 0)
@@ -792,12 +786,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
get_users_with_perms(
document,
@@ -805,12 +795,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(document.title, "simple")
expected_str = f"Document did not match {w}"
@@ -898,6 +884,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertEqual(
list(document.custom_fields.all().values_list("field", flat=True)),
[self.cf1.pk],
@@ -1968,6 +1955,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertRegex(
document.title,
r"Doc added in \w{3,}",
@@ -2064,11 +2052,11 @@ class TestWorkflows(
format="json",
)
view_users_perms: QuerySet = get_users_with_perms(
view_users_perms: QuerySet[Any] = get_users_with_perms(
doc,
only_with_perms_in=["view_document"],
)
change_users_perms: QuerySet = get_users_with_perms(
change_users_perms: QuerySet[Any] = get_users_with_perms(
doc,
only_with_perms_in=["change_document"],
)
@@ -2079,7 +2067,7 @@ class TestWorkflows(
self.assertIn(self.user3, view_users_perms)
self.assertIn(self.user3, change_users_perms)
group_perms: QuerySet = get_groups_with_perms(doc)
group_perms: QuerySet[Any] = get_groups_with_perms(doc)
# group1 should still have permissions
self.assertIn(self.group1, group_perms)
# group2 should have been added
@@ -2845,7 +2833,7 @@ class TestWorkflows(
self.assertEqual(doc.custom_fields.all().count(), 0)
self.assertFalse(self.user3.has_perm("documents.view_document", doc))
self.assertFalse(self.user3.has_perm("documents.change_document", doc))
group_perms: QuerySet = get_groups_with_perms(doc)
group_perms: QuerySet[Any] = get_groups_with_perms(doc)
self.assertNotIn(self.group1, group_perms)
def test_document_updated_workflow_assignment_persists_when_removing_trigger_tag(
@@ -2979,7 +2967,7 @@ class TestWorkflows(
self.assertEqual(doc.custom_fields.all().count(), 0)
self.assertFalse(self.user3.has_perm("documents.view_document", doc))
self.assertFalse(self.user3.has_perm("documents.change_document", doc))
group_perms: QuerySet = get_groups_with_perms(doc)
group_perms: QuerySet[Any] = get_groups_with_perms(doc)
self.assertNotIn(self.group1, group_perms)
def test_removal_action_document_consumed(self) -> None:
@@ -3057,6 +3045,7 @@ class TestWorkflows(
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
@@ -3179,6 +3168,7 @@ class TestWorkflows(
None,
)
document = Document.objects.first()
assert document is not None
self.assertIsNone(document.correspondent)
self.assertIsNone(document.document_type)
self.assertEqual(document.tags.all().count(), 0)
@@ -3192,12 +3182,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
get_users_with_perms(
document,
@@ -3205,12 +3191,8 @@ class TestWorkflows(
).count(),
0,
)
self.assertEqual(
get_groups_with_perms(
document,
).count(),
0,
)
group_perms: QuerySet[Any] = get_groups_with_perms(document)
self.assertEqual(group_perms.count(), 0)
self.assertEqual(
document.custom_fields.all()
.values_list(
@@ -3279,7 +3261,10 @@ class TestWorkflows(
PAPERLESS_URL="http://localhost:8000",
)
@mock.patch("django.core.mail.message.EmailMessage.send")
def test_workflow_assignment_then_email_includes_attachment(self, mock_email_send):
def test_workflow_assignment_then_email_includes_attachment(
self,
mock_email_send,
) -> None:
"""
GIVEN:
- Workflow with assignment and email actions
@@ -3828,7 +3813,7 @@ class TestWorkflows(
def test_workflow_webhook_action_does_not_overwrite_concurrent_tags(
self,
mock_execute_webhook_action,
):
) -> None:
"""
GIVEN:
- A document updated workflow with only a webhook action
@@ -3882,7 +3867,7 @@ class TestWorkflows(
def test_workflow_tag_actions_do_not_overwrite_concurrent_tags(
self,
mock_execute_webhook_action,
):
) -> None:
"""
GIVEN:
- A document updated workflow that clears tags and assigns an inbox tag
@@ -4160,7 +4145,7 @@ class TestWorkflows(
def test_password_removal_action_attempts_multiple_passwords(
self,
mock_remove_password,
):
) -> None:
"""
GIVEN:
- Workflow password removal action
@@ -4214,7 +4199,7 @@ class TestWorkflows(
def test_password_removal_action_fails_without_correct_password(
self,
mock_remove_password,
):
) -> None:
"""
GIVEN:
- Workflow password removal action
@@ -4247,7 +4232,7 @@ class TestWorkflows(
def test_password_removal_action_skips_without_passwords(
self,
mock_remove_password,
):
) -> None:
"""
GIVEN:
- Workflow password removal action with no passwords
@@ -4279,7 +4264,7 @@ class TestWorkflows(
def test_password_removal_consumable_document_deferred(
self,
mock_remove_password,
):
) -> None:
"""
GIVEN:
- Workflow password removal action
@@ -4346,7 +4331,7 @@ class TestWorkflows(
)
assert mock_remove_password.call_count == 2
def test_workflow_trash_action_soft_delete(self):
def test_workflow_trash_action_soft_delete(self) -> None:
"""
GIVEN:
- Document updated workflow with delete action
@@ -4389,7 +4374,7 @@ class TestWorkflows(
PAPERLESS_URL="http://localhost:8000",
)
@mock.patch("django.core.mail.message.EmailMessage.send")
def test_workflow_trash_with_email_action(self, mock_email_send):
def test_workflow_trash_with_email_action(self, mock_email_send) -> None:
"""
GIVEN:
- Workflow with email action, then move to trash action
@@ -4444,7 +4429,7 @@ class TestWorkflows(
PAPERLESS_URL="http://localhost:8000",
)
@mock.patch("documents.workflows.webhooks.send_webhook.apply_async")
def test_workflow_trash_with_webhook_action(self, mock_webhook_delay):
def test_workflow_trash_with_webhook_action(self, mock_webhook_delay) -> None:
"""
GIVEN:
- Workflow with webhook action (include_document=True), then move to trash action
@@ -4577,7 +4562,7 @@ class TestWorkflows(
self.assertEqual(Document.objects.count(), 0)
self.assertEqual(Document.deleted_objects.count(), 1)
def test_multiple_workflows_trash_then_assignment(self):
def test_multiple_workflows_trash_then_assignment(self) -> None:
"""
GIVEN:
- Workflow 1 (order=0) with move to trash action
@@ -4646,7 +4631,7 @@ class TestWorkflows(
log_output,
)
def test_workflow_delete_action_during_consumption(self):
def test_workflow_delete_action_during_consumption(self) -> None:
"""
GIVEN:
- Workflow with consumption trigger and delete action
@@ -4705,7 +4690,7 @@ class TestWorkflows(
# No document should be created
self.assertEqual(Document.objects.count(), 0)
def test_workflow_delete_action_during_consumption_with_assignment(self):
def test_workflow_delete_action_during_consumption_with_assignment(self) -> None:
"""
GIVEN:
- Workflow with consumption trigger, assignment action, then delete action
@@ -5219,4 +5204,5 @@ class TestDateWorkflowLocalization(
None,
)
document = Document.objects.first()
assert document is not None
assert document.title == expected_title
+7 -7
View File
@@ -184,22 +184,22 @@ class FileSystemAssertsMixin:
Utilities for checks various state information of the file system
"""
def assertIsFile(self, path: PathLike | str) -> None:
def assertIsFile(self, path: PathLike[str] | str) -> None:
self.assertTrue(Path(path).resolve().is_file(), f"File does not exist: {path}")
def assertIsNotFile(self, path: PathLike | str) -> None:
def assertIsNotFile(self, path: PathLike[str] | str) -> None:
self.assertFalse(Path(path).resolve().is_file(), f"File does exist: {path}")
def assertIsDir(self, path: PathLike | str) -> None:
def assertIsDir(self, path: PathLike[str] | str) -> None:
self.assertTrue(Path(path).resolve().is_dir(), f"Dir does not exist: {path}")
def assertIsNotDir(self, path: PathLike | str) -> None:
def assertIsNotDir(self, path: PathLike[str] | str) -> None:
self.assertFalse(Path(path).resolve().is_dir(), f"Dir does exist: {path}")
def assertFilesEqual(
self,
path1: PathLike | str,
path2: PathLike | str,
path1: PathLike[str] | str,
path2: PathLike[str] | str,
) -> None:
path1 = Path(path1)
path2 = Path(path2)
@@ -210,7 +210,7 @@ class FileSystemAssertsMixin:
self.assertEqual(hash1, hash2, "File SHA256 mismatch")
def assertFileCountInDir(self, path: PathLike | str, count: int) -> None:
def assertFileCountInDir(self, path: PathLike[str] | str, count: int) -> None:
path = Path(path).resolve()
self.assertTrue(path.is_dir(), f"Path {path} is not a directory")
files = [x for x in path.iterdir() if x.is_file()]
+1 -1
View File
@@ -2009,7 +2009,7 @@ class DocumentViewSet(
)
class ChatStreamingSerializer(serializers.Serializer):
class ChatStreamingSerializer(serializers.Serializer[dict[str, Any]]):
q = serializers.CharField(required=True)
document_id = serializers.IntegerField(required=False, allow_null=True)
+2 -4
View File
@@ -1,6 +1,4 @@
import grp
import os
import pwd
import shutil
import stat
import subprocess
@@ -38,8 +36,8 @@ def path_check(var: str, directory: Path) -> list[Error]:
except PermissionError:
dir_stat: os.stat_result = Path(directory).stat()
dir_mode: str = stat.filemode(dir_stat.st_mode)
dir_owner: str = pwd.getpwuid(dir_stat.st_uid).pw_name
dir_group: str = grp.getgrgid(dir_stat.st_gid).gr_name
dir_owner: str = ""
dir_group: str = ""
messages.append(
Error(
writeable_message.format(var),
+8 -5
View File
@@ -38,7 +38,9 @@ class OutputTypeConfig(BaseConfig):
def __post_init__(self) -> None:
app_config = self._get_config_instance()
self.output_type = app_config.output_type or settings.OCR_OUTPUT_TYPE
self.output_type = app_config.output_type or OutputTypeChoices(
settings.OCR_OUTPUT_TYPE,
)
@dataclasses.dataclass
@@ -70,12 +72,13 @@ class OcrConfig(OutputTypeConfig):
self.pages = app_config.pages or settings.OCR_PAGES
self.language = app_config.language or settings.OCR_LANGUAGE
self.mode = app_config.mode or settings.OCR_MODE
self.mode = app_config.mode or ModeChoices(settings.OCR_MODE)
self.archive_file_generation = (
app_config.archive_file_generation or settings.ARCHIVE_FILE_GENERATION
app_config.archive_file_generation
or ArchiveFileGenerationChoices(settings.ARCHIVE_FILE_GENERATION)
)
self.image_dpi = app_config.image_dpi or settings.OCR_IMAGE_DPI
self.clean = app_config.unpaper_clean or settings.OCR_CLEAN
self.clean = app_config.unpaper_clean or CleanChoices(settings.OCR_CLEAN)
self.deskew = (
app_config.deskew if app_config.deskew is not None else settings.OCR_DESKEW
)
@@ -92,7 +95,7 @@ class OcrConfig(OutputTypeConfig):
)
self.color_conversion_strategy = (
app_config.color_conversion_strategy
or settings.OCR_COLOR_CONVERSION_STRATEGY
or ColorConvertChoices(settings.OCR_COLOR_CONVERSION_STRATEGY)
)
user_args = None
@@ -766,7 +766,11 @@ class TestParser:
content=b"Pretend merged PDF content",
)
def test_layout_option(layout_option, expected_calls, expected_pdf_names):
def test_layout_option(
layout_option,
expected_calls,
expected_pdf_names,
) -> None:
mock_mailrule_get.return_value = mock.Mock(pdf_layout=layout_option)
mail_parser.configure(ParserContext(mailrule_id=1))
mail_parser.parse(
@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING
import pytest
from paperless.models import ModeChoices
if TYPE_CHECKING:
from pytest_mock import MockerFixture
@@ -72,7 +74,7 @@ class TestAutoModeWithText:
)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
simple_digital_pdf_file,
"application/pdf",
@@ -106,7 +108,7 @@ class TestAutoModeWithText:
)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
simple_digital_pdf_file,
"application/pdf",
@@ -158,7 +160,7 @@ class TestAutoModeNoText:
mocker.patch.object(tesseract_parser, "extract_text", side_effect=_extract_side)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
multi_page_images_pdf_file,
"application/pdf",
@@ -200,7 +202,7 @@ class TestAutoModeNoText:
mocker.patch.object(tesseract_parser, "extract_text", side_effect=_extract_side)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
multi_page_images_pdf_file,
"application/pdf",
@@ -243,7 +245,7 @@ class TestOffModePdf:
)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "off"
tesseract_parser.settings.mode = ModeChoices.OFF
tesseract_parser.parse(
simple_digital_pdf_file,
"application/pdf",
@@ -283,7 +285,7 @@ class TestOffModePdf:
)
mocker.patch("ocrmypdf.pdfa.generate_pdfa_ps")
tesseract_parser.settings.mode = "off"
tesseract_parser.settings.mode = ModeChoices.OFF
tesseract_parser.parse(
simple_digital_pdf_file,
"application/pdf",
@@ -323,7 +325,7 @@ class TestOffModeImage:
"""
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "off"
tesseract_parser.settings.mode = ModeChoices.OFF
tesseract_parser.parse(simple_png_file, "image/png", produce_archive=False)
mock_ocr.assert_not_called()
@@ -355,7 +357,7 @@ class TestOffModeImage:
)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "off"
tesseract_parser.settings.mode = ModeChoices.OFF
tesseract_parser.parse(simple_png_file, "image/png", produce_archive=True)
mock_convert.assert_called_once_with(simple_png_file)
@@ -429,7 +431,7 @@ class TestProduceArchiveFalse:
)
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
simple_digital_pdf_file,
"application/pdf",
@@ -44,6 +44,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_PAGES=10):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.pages = 5
instance.save()
@@ -62,6 +63,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_LANGUAGE="eng+deu"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.language = "fra+ita"
instance.save()
@@ -80,6 +82,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_OUTPUT_TYPE="pdfa-3"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.output_type = OutputTypeChoices.PDF_A
instance.save()
@@ -100,6 +103,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
# AUTO mode with skip_text=True explicitly passed: skip_text is set
with override_settings(OCR_MODE="redo"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.mode = ModeChoices.AUTO
instance.save()
@@ -118,6 +122,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
# AUTO mode alone (no skip_text): no extra OCR flag is set
with override_settings(OCR_MODE="redo"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.mode = ModeChoices.AUTO
instance.save()
@@ -138,6 +143,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_CLEAN="clean-final"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.unpaper_clean = CleanChoices.CLEAN
instance.save()
@@ -147,6 +153,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
with override_settings(OCR_CLEAN="clean-final"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.unpaper_clean = CleanChoices.FINAL
instance.save()
@@ -166,6 +173,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_DESKEW=False):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.deskew = True
instance.save()
@@ -185,6 +193,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
with override_settings(OCR_ROTATE_PAGES=False, OCR_ROTATE_PAGES_THRESHOLD=30.0):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
assert instance is not None
instance.rotate_pages = True
instance.rotate_pages_threshold = 15.0
instance.save()
@@ -205,6 +214,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_MAX_IMAGE_PIXELS=2_000_000.0):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.max_image_pixels = 1_000_000.0
instance.save()
@@ -223,6 +233,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
"""
with override_settings(OCR_COLOR_CONVERSION_STRATEGY="LeaveColorUnchanged"):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.color_conversion_strategy = ColorConvertChoices.INDEPENDENT
instance.save()
@@ -246,6 +257,7 @@ class TestParserSettingsFromDb(DirectoriesMixin, FileSystemAssertsMixin, TestCas
OCR_USER_ARGS=json.dumps({"continue_on_soft_render_error": True}),
):
instance = ApplicationConfiguration.objects.all().first()
assert instance is not None
instance.user_args = {"unpaper_args": "--pre-rotate 90"}
instance.save()
@@ -18,6 +18,7 @@ from ocrmypdf import SubprocessOutputError
from documents.parsers import ParseError
from documents.parsers import run_convert
from paperless.models import ModeChoices
from paperless.parsers import ParserProtocol
from paperless.parsers.tesseract import RasterisedDocumentParser
from paperless.parsers.tesseract import post_process_text
@@ -387,8 +388,10 @@ class TestParsePdf:
)
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -413,7 +416,7 @@ class TestParsePdf:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "redo"
tesseract_parser.settings.mode = ModeChoices.REDO
tesseract_parser.parse(
tesseract_samples_dir / "with-form.pdf",
"application/pdf",
@@ -430,7 +433,7 @@ class TestParsePdf:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "force"
tesseract_parser.settings.mode = ModeChoices.FORCE
tesseract_parser.parse(
tesseract_samples_dir / "with-form.pdf",
"application/pdf",
@@ -445,7 +448,7 @@ class TestParsePdf:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(tesseract_samples_dir / "signed.pdf", "application/pdf")
assert tesseract_parser.archive_path is None
assert_ordered_substrings(
@@ -461,7 +464,7 @@ class TestParsePdf:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "encrypted.pdf",
"application/pdf",
@@ -530,7 +533,9 @@ class TestParseImages:
tesseract_parser.parse(tesseract_samples_dir / "simple-no-dpi.png", "image/png")
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
assert "this is a test document." in tesseract_parser.get_text().lower()
_text = tesseract_parser.get_text()
assert _text is not None
assert "this is a test document." in _text.lower()
def test_no_dpi_no_fallback_raises(
self,
@@ -563,8 +568,10 @@ class TestParseMultiPage:
)
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -589,8 +596,10 @@ class TestParseMultiPage:
"application/pdf",
)
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -599,14 +608,16 @@ class TestParseMultiPage:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-images.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -626,13 +637,15 @@ class TestParseMultiPage:
- Pages 1 and 2 extracted; page 3 absent
"""
tesseract_parser.settings.pages = 2
tesseract_parser.settings.mode = "redo"
tesseract_parser.settings.mode = ModeChoices.REDO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-images.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
text = tesseract_parser.get_text().lower()
text = tesseract_parser.get_text()
assert text is not None
text = text.lower()
assert_ordered_substrings(text, ["page 1", "page 2"])
assert "page 3" not in text
@@ -652,13 +665,15 @@ class TestParseMultiPage:
- Only page 1 extracted
"""
tesseract_parser.settings.pages = 1
tesseract_parser.settings.mode = "force"
tesseract_parser.settings.mode = ModeChoices.FORCE
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-images.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
text = tesseract_parser.get_text().lower()
text = tesseract_parser.get_text()
assert text is not None
text = text.lower()
assert "page 1" in text
assert "page 2" not in text
assert "page 3" not in text
@@ -681,8 +696,10 @@ class TestParseMultiPage:
"image/tiff",
)
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -704,8 +721,10 @@ class TestParseMultiPage:
shutil.copy(tesseract_samples_dir / "multi-page-images-alpha.tiff", dest)
tesseract_parser.parse(dest, "image/tiff")
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -727,8 +746,10 @@ class TestParseMultiPage:
shutil.copy(tesseract_samples_dir / "multi-page-images-alpha-rgb.tiff", dest)
tesseract_parser.parse(dest, "image/tiff")
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -754,15 +775,17 @@ class TestSkipArchive:
- Text extracted from original; no archive created (text exists +
produce_archive=False skips OCRmyPDF entirely)
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-digital.pdf",
"application/pdf",
produce_archive=False,
)
assert tesseract_parser.archive_path is None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -780,14 +803,16 @@ class TestSkipArchive:
THEN:
- Text extracted; archive created (OCR needed, no existing text)
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-images.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3"],
)
@@ -838,13 +863,15 @@ class TestSkipArchive:
- archive_path is set if and only if produce_archive=True
- Text is always extracted
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / filename,
"application/pdf",
produce_archive=produce_archive,
)
text = tesseract_parser.get_text().lower()
text = tesseract_parser.get_text()
assert text is not None
text = text.lower()
assert_ordered_substrings(text, ["page 1", "page 2", "page 3"])
if expect_archive:
assert tesseract_parser.archive_path is not None
@@ -868,7 +895,7 @@ class TestSkipArchive:
- Text is extracted from the original via pdftotext
- No archive is produced
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
mock_ocr = mocker.patch("ocrmypdf.ocr")
tesseract_parser.parse(
tesseract_samples_dir / "simple-digital.pdf",
@@ -895,7 +922,7 @@ class TestSkipArchive:
- Archive is produced
- Text is preserved from the original
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "simple-digital.pdf",
"application/pdf",
@@ -925,15 +952,17 @@ class TestParseMixed:
THEN:
- All pages extracted; archive created; sidecar notes skipped pages
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-mixed.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 1", "page 2", "page 3", "page 4", "page 5", "page 6"],
)
sidecar = (tesseract_parser.tempdir / "sidecar.txt").read_text()
@@ -953,15 +982,17 @@ class TestParseMixed:
THEN:
- Both text layer and image text extracted; archive created
"""
tesseract_parser.settings.mode = "redo"
tesseract_parser.settings.mode = ModeChoices.REDO
tesseract_parser.parse(
tesseract_samples_dir / "single-page-mixed.pdf",
"application/pdf",
)
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
[
"this is some normal text, present on page 1 of the document.",
"this is some text, but in an image, also on page 1.",
@@ -989,15 +1020,17 @@ class TestParseMixed:
THEN:
- No archive created (produce_archive=False); text from text layer present
"""
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.parse(
tesseract_samples_dir / "multi-page-mixed.pdf",
"application/pdf",
produce_archive=False,
)
assert tesseract_parser.archive_path is None
_text = tesseract_parser.get_text()
assert _text is not None
assert_ordered_substrings(
tesseract_parser.get_text().lower(),
_text.lower(),
["page 4", "page 5", "page 6"],
)
@@ -1013,7 +1046,7 @@ class TestParseRotate:
tesseract_parser: RasterisedDocumentParser,
tesseract_samples_dir: Path,
) -> None:
tesseract_parser.settings.mode = "auto"
tesseract_parser.settings.mode = ModeChoices.AUTO
tesseract_parser.settings.rotate = True
tesseract_parser.parse(tesseract_samples_dir / "rotated.pdf", "application/pdf")
assert_ordered_substrings(
@@ -1052,14 +1085,16 @@ class TestParseRtl:
force-ocr with English Tesseract (producing garbage). Using mode="off" forces
skip_text=True so the Arabic text layer is preserved through PDF/A conversion.
"""
tesseract_parser.settings.mode = "off"
tesseract_parser.settings.mode = ModeChoices.OFF
tesseract_parser.parse(
tesseract_samples_dir / "rtl-test.pdf",
"application/pdf",
)
_text = tesseract_parser.get_text()
assert _text is not None
normalised = "".join(
ch
for ch in unicodedata.normalize("NFKC", tesseract_parser.get_text())
for ch in unicodedata.normalize("NFKC", _text)
if unicodedata.category(ch) != "Cf" and not ch.isspace()
)
assert "ةرازو" in normalised
@@ -1196,7 +1231,9 @@ class TestParserFileTypes:
tesseract_parser.parse(tesseract_samples_dir / filename, mime_type)
assert tesseract_parser.archive_path is not None
assert tesseract_parser.archive_path.is_file()
assert "this is a test document" in tesseract_parser.get_text().lower()
_text = tesseract_parser.get_text()
assert _text is not None
assert "this is a test document" in _text.lower()
def test_heic(
self,
@@ -1205,7 +1242,9 @@ class TestParserFileTypes:
) -> None:
tesseract_parser.parse(tesseract_samples_dir / "simple.heic", "image/heic")
assert tesseract_parser.archive_path is not None
assert "pizza" in tesseract_parser.get_text().lower()
_text = tesseract_parser.get_text()
assert _text is not None
assert "pizza" in _text.lower()
def test_gif_with_explicit_dpi(
self,
@@ -1215,7 +1254,9 @@ class TestParserFileTypes:
tesseract_parser.settings.image_dpi = 200
tesseract_parser.parse(tesseract_samples_dir / "simple.gif", "image/gif")
assert tesseract_parser.archive_path is not None
assert "this is a test document" in tesseract_parser.get_text().lower()
_text = tesseract_parser.get_text()
assert _text is not None
assert "this is a test document" in _text.lower()
def test_webp_with_explicit_dpi(
self,
@@ -1225,9 +1266,11 @@ class TestParserFileTypes:
tesseract_parser.settings.image_dpi = 72
tesseract_parser.parse(tesseract_samples_dir / "document.webp", "image/webp")
assert tesseract_parser.archive_path is not None
_text = tesseract_parser.get_text()
assert _text is not None
assert re.search(
r"this is a ?webp document, created 11/14/2022\.",
tesseract_parser.get_text().lower(),
_text.lower(),
)
@@ -26,7 +26,7 @@ class TestStringToBool:
pytest.param(" True ", id="whitespace_true"),
],
)
def test_true_conversion(self, true_value: str):
def test_true_conversion(self, true_value: str) -> None:
"""Test that various 'true' strings correctly evaluate to True."""
assert str_to_bool(true_value) is True
@@ -41,18 +41,18 @@ class TestStringToBool:
pytest.param(" False ", id="whitespace_false"),
],
)
def test_false_conversion(self, false_value: str):
def test_false_conversion(self, false_value: str) -> None:
"""Test that various 'false' strings correctly evaluate to False."""
assert str_to_bool(false_value) is False
def test_invalid_conversion(self):
def test_invalid_conversion(self) -> None:
"""Test that an invalid string raises a ValueError."""
with pytest.raises(ValueError, match="Cannot convert 'maybe' to a boolean\\."):
str_to_bool("maybe")
class TestParseDictFromString:
def test_empty_and_none_input(self):
def test_empty_and_none_input(self) -> None:
"""Test behavior with None or empty string input."""
assert parse_dict_from_str(None) == {}
assert parse_dict_from_str("") == {}
@@ -62,13 +62,13 @@ class TestParseDictFromString:
# Ensure it returns a copy, not the original object
assert res is not defaults
def test_basic_parsing(self):
def test_basic_parsing(self) -> None:
"""Test simple key-value parsing without defaults or types."""
env_str = "key1=val1, key2=val2"
expected = {"key1": "val1", "key2": "val2"}
assert parse_dict_from_str(env_str) == expected
def test_with_defaults(self):
def test_with_defaults(self) -> None:
"""Test that environment values override defaults correctly."""
defaults = {"host": "localhost", "port": 8000, "user": "default"}
env_str = "port=9090, host=db.example.com"
@@ -76,7 +76,7 @@ class TestParseDictFromString:
result = parse_dict_from_str(env_str, defaults=defaults)
assert result == expected
def test_type_casting(self):
def test_type_casting(self) -> None:
"""Test successful casting of values to specified types."""
env_str = "port=9090, debug=true, timeout=12.5, user=admin"
type_map = {"port": int, "debug": bool, "timeout": float}
@@ -84,7 +84,7 @@ class TestParseDictFromString:
result = parse_dict_from_str(env_str, type_map=type_map)
assert result == expected
def test_type_casting_with_defaults(self):
def test_type_casting_with_defaults(self) -> None:
"""Test casting when values come from both defaults and env string."""
defaults = {"port": 8000, "debug": False, "retries": 3}
env_str = "port=9090, debug=true"
@@ -97,7 +97,7 @@ class TestParseDictFromString:
assert result == expected
assert isinstance(result["retries"], int)
def test_path_casting(self, tmp_path: Path):
def test_path_casting(self, tmp_path: Path) -> None:
"""Test successful casting of a string to a resolved pathlib.Path object."""
# Create a dummy file to resolve against
test_file = tmp_path / "test_file.txt"
@@ -111,14 +111,14 @@ class TestParseDictFromString:
assert isinstance(result["config_path"], Path)
assert result["config_path"] == test_file.resolve()
def test_custom_separator(self):
def test_custom_separator(self) -> None:
"""Test parsing with a custom separator like a semicolon."""
env_str = "host=db; port=5432; user=test"
expected = {"host": "db", "port": "5432", "user": "test"}
result = parse_dict_from_str(env_str, separator=";")
assert result == expected
def test_edge_cases_in_string(self):
def test_edge_cases_in_string(self) -> None:
"""Test malformed strings to ensure robustness."""
# Malformed pair 'debug' is skipped, extra comma is ignored
env_str = "key=val,, debug, foo=bar"
@@ -130,7 +130,7 @@ class TestParseDictFromString:
expected = {"url": "postgres://user:pass@host:5432/db"}
assert parse_dict_from_str(env_str) == expected
def test_casting_error_handling(self):
def test_casting_error_handling(self) -> None:
"""Test that a ValueError is raised for invalid casting."""
env_str = "port=not-a-number"
type_map = {"port": int}
@@ -142,14 +142,14 @@ class TestParseDictFromString:
assert "value 'not-a-number'" in str(excinfo.value)
assert "to type 'int'" in str(excinfo.value)
def test_bool_casting_error(self):
def test_bool_casting_error(self) -> None:
"""Test that an invalid boolean string raises a ValueError."""
env_str = "debug=maybe"
type_map = {"debug": bool}
with pytest.raises(ValueError, match="Error casting key 'debug'"):
parse_dict_from_str(env_str, type_map=type_map)
def test_nested_key_parsing_basic(self):
def test_nested_key_parsing_basic(self) -> None:
"""Basic nested key parsing using dot-notation."""
env_str = "database.host=db.example.com, database.port=5432, logging.level=INFO"
result = parse_dict_from_str(env_str)
@@ -158,7 +158,7 @@ class TestParseDictFromString:
"logging": {"level": "INFO"},
}
def test_nested_overrides_defaults_and_deepcopy(self):
def test_nested_overrides_defaults_and_deepcopy(self) -> None:
"""Nested env keys override defaults and defaults are deep-copied."""
defaults = {"database": {"host": "127.0.0.1", "port": 3306, "user": "default"}}
env_str = "database.host=db.example.com, debug=true"
@@ -176,7 +176,7 @@ class TestParseDictFromString:
assert result is not defaults
assert result["database"] is not defaults["database"]
def test_nested_type_casting(self):
def test_nested_type_casting(self) -> None:
"""Type casting for nested keys (dot-notation) should work."""
env_str = "database.host=db.example.com, database.port=5433, debug=false"
type_map = {"database.port": int, "debug": bool}
@@ -188,7 +188,7 @@ class TestParseDictFromString:
assert result["debug"] is False
assert isinstance(result["debug"], bool)
def test_nested_casting_error_message(self):
def test_nested_casting_error_message(self) -> None:
"""Error messages should include the full dotted key name on failure."""
env_str = "database.port=not-a-number"
type_map = {"database.port": int}
@@ -200,7 +200,7 @@ class TestParseDictFromString:
assert "value 'not-a-number'" in msg
assert "to type 'int'" in msg
def test_type_map_does_not_recast_non_string_defaults(self):
def test_type_map_does_not_recast_non_string_defaults(self) -> None:
"""If a default already provides a non-string value, the caster should skip it."""
defaults = {"database": {"port": 3306}}
type_map = {"database.port": int}
@@ -210,22 +210,22 @@ class TestParseDictFromString:
class TestGetBoolFromEnv:
def test_existing_env_var(self, mocker):
def test_existing_env_var(self, mocker) -> None:
"""Test that an existing environment variable is read and converted."""
mocker.patch.dict(os.environ, {"TEST_VAR": "true"})
assert get_bool_from_env("TEST_VAR") is True
def test_missing_env_var_uses_default_no(self, mocker):
def test_missing_env_var_uses_default_no(self, mocker) -> None:
"""Test that a missing environment variable uses default 'NO' and returns False."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_bool_from_env("MISSING_VAR") is False
def test_missing_env_var_with_explicit_default(self, mocker):
def test_missing_env_var_with_explicit_default(self, mocker) -> None:
"""Test that a missing environment variable uses the provided default."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_bool_from_env("MISSING_VAR", default="yes") is True
def test_invalid_value_raises_error(self, mocker):
def test_invalid_value_raises_error(self, mocker) -> None:
"""Test that an invalid value raises ValueError (delegates to str_to_bool)."""
mocker.patch.dict(os.environ, {"INVALID_VAR": "maybe"})
with pytest.raises(ValueError):
@@ -243,7 +243,7 @@ class TestGetIntFromEnv:
pytest.param("-999", -999, id="large_negative"),
],
)
def test_existing_env_var_valid_ints(self, mocker, env_value, expected):
def test_existing_env_var_valid_ints(self, mocker, env_value, expected) -> None:
"""Test that existing environment variables with valid integers return correct values."""
mocker.patch.dict(os.environ, {"INT_VAR": env_value})
assert get_int_from_env("INT_VAR") == expected
@@ -257,12 +257,12 @@ class TestGetIntFromEnv:
pytest.param(None, None, id="none_default"),
],
)
def test_missing_env_var_with_defaults(self, mocker, default, expected):
def test_missing_env_var_with_defaults(self, mocker, default, expected) -> None:
"""Test that missing environment variables return provided defaults."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_int_from_env("MISSING_VAR", default=default) == expected
def test_missing_env_var_no_default(self, mocker):
def test_missing_env_var_no_default(self, mocker) -> None:
"""Test that missing environment variable with no default returns None."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_int_from_env("MISSING_VAR") is None
@@ -279,7 +279,7 @@ class TestGetIntFromEnv:
pytest.param("1.0", id="decimal"),
],
)
def test_invalid_int_values_raise_error(self, mocker, invalid_value):
def test_invalid_int_values_raise_error(self, mocker, invalid_value) -> None:
"""Test that invalid integer values raise ValueError."""
mocker.patch.dict(os.environ, {"INVALID_INT": invalid_value})
with pytest.raises(ValueError):
@@ -300,7 +300,7 @@ class TestGetFloatFromEnv:
pytest.param("-1.23e4", -12300.0, id="sci_large"),
],
)
def test_existing_env_var_valid_floats(self, mocker, env_value, expected):
def test_existing_env_var_valid_floats(self, mocker, env_value, expected) -> None:
"""Test that existing environment variables with valid floats return correct values."""
mocker.patch.dict(os.environ, {"FLOAT_VAR": env_value})
assert get_float_from_env("FLOAT_VAR") == expected
@@ -314,12 +314,12 @@ class TestGetFloatFromEnv:
pytest.param(None, None, id="none_default"),
],
)
def test_missing_env_var_with_defaults(self, mocker, default, expected):
def test_missing_env_var_with_defaults(self, mocker, default, expected) -> None:
"""Test that missing environment variables return provided defaults."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_float_from_env("MISSING_VAR", default=default) == expected
def test_missing_env_var_no_default(self, mocker):
def test_missing_env_var_no_default(self, mocker) -> None:
"""Test that missing environment variable with no default returns None."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_float_from_env("MISSING_VAR") is None
@@ -336,7 +336,7 @@ class TestGetFloatFromEnv:
pytest.param("1.2.3", id="triple_decimal"),
],
)
def test_invalid_float_values_raise_error(self, mocker, invalid_value):
def test_invalid_float_values_raise_error(self, mocker, invalid_value) -> None:
"""Test that invalid float values raise ValueError."""
mocker.patch.dict(os.environ, {"INVALID_FLOAT": invalid_value})
with pytest.raises(ValueError):
@@ -355,19 +355,19 @@ class TestGetPathFromEnv:
pytest.param("/", id="root"),
],
)
def test_existing_env_var_paths(self, mocker, env_value):
def test_existing_env_var_paths(self, mocker, env_value) -> None:
"""Test that existing environment variables with paths return resolved Path objects."""
mocker.patch.dict(os.environ, {"PATH_VAR": env_value})
result = get_path_from_env("PATH_VAR")
assert isinstance(result, Path)
assert result == Path(env_value).resolve()
def test_missing_env_var_no_default(self, mocker):
def test_missing_env_var_no_default(self, mocker) -> None:
"""Test that missing environment variable with no default returns None."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_path_from_env("MISSING_VAR") is None
def test_missing_env_var_with_none_default(self, mocker):
def test_missing_env_var_with_none_default(self, mocker) -> None:
"""Test that missing environment variable with None default returns None."""
mocker.patch.dict(os.environ, {}, clear=True)
assert get_path_from_env("MISSING_VAR", default=None) is None
@@ -380,7 +380,7 @@ class TestGetPathFromEnv:
pytest.param(".", id="current_default"),
],
)
def test_missing_env_var_with_path_defaults(self, mocker, default_path_str):
def test_missing_env_var_with_path_defaults(self, mocker, default_path_str) -> None:
"""Test that missing environment variables return resolved default Path objects."""
mocker.patch.dict(os.environ, {}, clear=True)
default_path = Path(default_path_str)
@@ -388,7 +388,7 @@ class TestGetPathFromEnv:
assert isinstance(result, Path)
assert result == default_path.resolve()
def test_relative_paths_are_resolved(self, mocker):
def test_relative_paths_are_resolved(self, mocker) -> None:
"""Test that relative paths are properly resolved to absolute paths."""
mocker.patch.dict(os.environ, {"REL_PATH": "relative/path"})
result = get_path_from_env("REL_PATH")
@@ -407,7 +407,7 @@ class TestGetListFromEnv:
pytest.param("a,,b,c", ["a", "b", "c"], id="empty_elements_removed"),
],
)
def test_existing_env_var_basic_parsing(self, mocker, env_value, expected):
def test_existing_env_var_basic_parsing(self, mocker, env_value, expected) -> None:
"""Test that existing environment variables are parsed correctly."""
mocker.patch.dict(os.environ, {"LIST_VAR": env_value})
result = get_list_from_env("LIST_VAR")
@@ -421,7 +421,7 @@ class TestGetListFromEnv:
pytest.param(";", "a;b;c", ["a", "b", "c"], id="semicolon_separator"),
],
)
def test_custom_separators(self, mocker, separator, env_value, expected):
def test_custom_separators(self, mocker, separator, env_value, expected) -> None:
"""Test that custom separators work correctly."""
mocker.patch.dict(os.environ, {"LIST_VAR": env_value})
result = get_list_from_env("LIST_VAR", separator=separator)
@@ -439,19 +439,19 @@ class TestGetListFromEnv:
pytest.param(None, [], id="none_default_returns_empty_list"),
],
)
def test_missing_env_var_with_defaults(self, mocker, default, expected):
def test_missing_env_var_with_defaults(self, mocker, default, expected) -> None:
"""Test that missing environment variables return provided defaults."""
mocker.patch.dict(os.environ, {}, clear=True)
result = get_list_from_env("MISSING_VAR", default=default)
assert result == expected
def test_missing_env_var_no_default(self, mocker):
def test_missing_env_var_no_default(self, mocker) -> None:
"""Test that missing environment variable with no default returns empty list."""
mocker.patch.dict(os.environ, {}, clear=True)
result = get_list_from_env("MISSING_VAR")
assert result == []
def test_required_env_var_missing_raises_error(self, mocker):
def test_required_env_var_missing_raises_error(self, mocker) -> None:
"""Test that missing required environment variable raises ValueError."""
mocker.patch.dict(os.environ, {}, clear=True)
with pytest.raises(
@@ -460,19 +460,19 @@ class TestGetListFromEnv:
):
get_list_from_env("REQUIRED_VAR", required=True)
def test_required_env_var_with_default_does_not_raise(self, mocker):
def test_required_env_var_with_default_does_not_raise(self, mocker) -> None:
"""Test that required environment variable with default does not raise error."""
mocker.patch.dict(os.environ, {}, clear=True)
result = get_list_from_env("REQUIRED_VAR", default=["default"], required=True)
assert result == ["default"]
def test_strip_whitespace_false(self, mocker):
def test_strip_whitespace_false(self, mocker) -> None:
"""Test that whitespace is preserved when strip_whitespace=False."""
mocker.patch.dict(os.environ, {"LIST_VAR": " a , b , c "})
result = get_list_from_env("LIST_VAR", strip_whitespace=False)
assert result == [" a ", " b ", " c "]
def test_remove_empty_false(self, mocker):
def test_remove_empty_false(self, mocker) -> None:
"""Test that empty elements are preserved when remove_empty=False."""
mocker.patch.dict(os.environ, {"LIST_VAR": "a,,b,,c"})
result = get_list_from_env("LIST_VAR", remove_empty=False)
+9 -9
View File
@@ -1,6 +1,7 @@
import hmac
import pickle
from hashlib import sha256
from pathlib import Path
import pytest
from django.test import override_settings
@@ -11,21 +12,20 @@ from paperless.celery import signed_pickle_loads
class TestSignedPickleSerializer:
def test_roundtrip_simple_types(self):
def test_roundtrip_simple_types(self) -> None:
"""Signed pickle can round-trip basic JSON-like types."""
for obj in [42, "hello", [1, 2, 3], {"key": "value"}, None, True]:
assert signed_pickle_loads(signed_pickle_dumps(obj)) == obj
def test_roundtrip_complex_types(self):
def test_roundtrip_complex_types(self) -> None:
"""Signed pickle can round-trip types that JSON cannot."""
from pathlib import Path
obj = {"path": Path("/tmp/test"), "data": {1, 2, 3}}
result = signed_pickle_loads(signed_pickle_dumps(obj))
assert result["path"] == Path("/tmp/test")
assert result["data"] == {1, 2, 3}
def test_tampered_data_rejected(self):
def test_tampered_data_rejected(self) -> None:
"""Flipping a byte in the data portion causes HMAC failure."""
payload = signed_pickle_dumps({"task": "test"})
tampered = bytearray(payload)
@@ -33,7 +33,7 @@ class TestSignedPickleSerializer:
with pytest.raises(ValueError, match="HMAC verification failed"):
signed_pickle_loads(bytes(tampered))
def test_tampered_signature_rejected(self):
def test_tampered_signature_rejected(self) -> None:
"""Flipping a byte in the signature portion causes HMAC failure."""
payload = signed_pickle_dumps({"task": "test"})
tampered = bytearray(payload)
@@ -41,17 +41,17 @@ class TestSignedPickleSerializer:
with pytest.raises(ValueError, match="HMAC verification failed"):
signed_pickle_loads(bytes(tampered))
def test_truncated_payload_rejected(self):
def test_truncated_payload_rejected(self) -> None:
"""A payload shorter than HMAC_SIZE is rejected."""
with pytest.raises(ValueError, match="too short"):
signed_pickle_loads(b"\x00" * (HMAC_SIZE - 1))
def test_empty_payload_rejected(self):
def test_empty_payload_rejected(self) -> None:
with pytest.raises(ValueError, match="too short"):
signed_pickle_loads(b"")
@override_settings(SECRET_KEY="different-secret-key")
def test_wrong_secret_key_rejected(self):
def test_wrong_secret_key_rejected(self) -> None:
"""A message signed with one key cannot be loaded with another."""
original_key = b"test-secret-key-do-not-use-in-production"
obj = {"task": "test"}
@@ -61,7 +61,7 @@ class TestSignedPickleSerializer:
with pytest.raises(ValueError, match="HMAC verification failed"):
signed_pickle_loads(payload)
def test_forged_pickle_rejected(self):
def test_forged_pickle_rejected(self) -> None:
"""A raw pickle payload (no signature) is rejected."""
raw_pickle = pickle.dumps({"task": "test"})
# Raw pickle won't have a valid HMAC prefix
@@ -45,45 +45,45 @@ class TestMigrateSkipArchiveFile(TestMigrations):
)
return ApplicationConfiguration.objects.get(pk=pk)
def test_skip_mapped_to_auto(self):
def test_skip_mapped_to_auto(self) -> None:
config = self._get_config(1)
assert config.mode == "auto"
def test_skip_archive_always_mapped_to_never(self):
def test_skip_archive_always_mapped_to_never(self) -> None:
config = self._get_config(1)
assert config.archive_file_generation == "never"
def test_redo_unchanged(self):
def test_redo_unchanged(self) -> None:
config = self._get_config(2)
assert config.mode == "redo"
def test_skip_archive_with_text_mapped_to_auto(self):
def test_skip_archive_with_text_mapped_to_auto(self) -> None:
config = self._get_config(2)
assert config.archive_file_generation == "auto"
def test_force_unchanged(self):
def test_force_unchanged(self) -> None:
config = self._get_config(3)
assert config.mode == "force"
def test_skip_archive_never_mapped_to_always(self):
def test_skip_archive_never_mapped_to_always(self) -> None:
config = self._get_config(3)
assert config.archive_file_generation == "always"
def test_skip_noarchive_mapped_to_auto(self):
def test_skip_noarchive_mapped_to_auto(self) -> None:
config = self._get_config(4)
assert config.mode == "auto"
def test_skip_noarchive_implies_archive_never(self):
def test_skip_noarchive_implies_archive_never(self) -> None:
config = self._get_config(4)
assert config.archive_file_generation == "never"
def test_skip_noarchive_explicit_skip_archive_takes_precedence(self):
def test_skip_noarchive_explicit_skip_archive_takes_precedence(self) -> None:
"""skip_archive_file=never maps to always, not overridden by skip_noarchive."""
config = self._get_config(5)
assert config.mode == "auto"
assert config.archive_file_generation == "always"
def test_null_values_remain_null(self):
def test_null_values_remain_null(self) -> None:
config = self._get_config(6)
assert config.mode is None
assert config.archive_file_generation is None
+3 -3
View File
@@ -9,7 +9,7 @@ from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail
class MailAccountFactory(DjangoModelFactory):
class MailAccountFactory(DjangoModelFactory[MailAccount]):
class Meta:
model = MailAccount
@@ -24,7 +24,7 @@ class MailAccountFactory(DjangoModelFactory):
is_token = False
class MailRuleFactory(DjangoModelFactory):
class MailRuleFactory(DjangoModelFactory[MailRule]):
class Meta:
model = MailRule
@@ -44,7 +44,7 @@ class MailRuleFactory(DjangoModelFactory):
stop_processing = False
class ProcessedMailFactory(DjangoModelFactory):
class ProcessedMailFactory(DjangoModelFactory[ProcessedMail]):
class Meta:
model = ProcessedMail
+5 -4
View File
@@ -1592,7 +1592,7 @@ class TestPostConsumeAction(TestCase):
mock_get_rule_action,
mock_mailbox_login,
mock_get_mailbox,
):
) -> None:
mock_mailbox = mock.MagicMock()
mock_get_mailbox.return_value.__enter__.return_value = mock_mailbox
mock_action = mock.MagicMock()
@@ -1625,7 +1625,7 @@ class TestPostConsumeAction(TestCase):
mock_get_rule_action,
mock_mailbox_login,
mock_get_mailbox,
):
) -> None:
mock_mailbox = mock.MagicMock()
mock_get_mailbox.return_value.__enter__.return_value = mock_mailbox
mock_action = mock.MagicMock()
@@ -1762,7 +1762,7 @@ class TestTasks(TestCase):
self.assertIn("No new", result)
@mock.patch("paperless_mail.tasks.MailAccountHandler.handle_mail_account")
def test_rule_with_stop_processing(self, m):
def test_rule_with_stop_processing(self, m) -> None:
"""
GIVEN:
- Mail account with a rule with stop_processing=True
@@ -1863,7 +1863,7 @@ class TestMailAccountTestView(APITestCase):
def test_mail_account_test_view_refresh_token(
self,
mock_refresh_account_oauth_token,
):
) -> None:
"""
GIVEN:
- Mail account with expired token
@@ -2050,6 +2050,7 @@ class TestMailRuleAPI(APITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(MailRule.objects.count(), 1)
rule = MailRule.objects.first()
assert rule is not None
self.assertEqual(rule.name, "Test Rule")
def test_mail_rule_action_parameter_required_for_tag_or_move(self) -> None:
+3 -3
View File
@@ -96,7 +96,7 @@ class TestMailOAuth(
self,
mock_get_outlook_access_token,
mock_get_gmail_access_token,
):
) -> None:
"""
GIVEN:
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
@@ -277,7 +277,7 @@ class TestMailOAuth(
self,
mock_refresh_token,
mock_get_mailbox,
):
) -> None:
"""
GIVEN:
- Mail account with refresh token and expiration
@@ -334,7 +334,7 @@ class TestMailOAuth(
self,
mock_refresh_token,
mock_get_mailbox,
):
) -> None:
"""
GIVEN:
- Mail account with refresh token and expiration