refactor: rely on transfer chain invariants

This commit is contained in:
jxxghp
2026-05-14 07:55:33 +08:00
parent f4423e121e
commit 0f3a4e4c15
2 changed files with 22 additions and 31 deletions

View File

@@ -799,10 +799,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
return StorageChain().is_bluray_folder(fileitem)
if not fileitem.extension:
return False
media_exts = (
self._media_exts if hasattr(self, "_media_exts") else settings.RMT_MEDIAEXT
)
return True if f".{fileitem.extension.lower()}" in media_exts else False
return True if f".{fileitem.extension.lower()}" in self._media_exts else False
def __is_allowed_file(self, fileitem: FileItem) -> bool:
"""
@@ -1179,8 +1176,6 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
if not task or not task.transfer_batch_id:
return
if not hasattr(self, "_scrape_batches"):
self._scrape_batches = {}
with job_lock:
batch = self._scrape_batches.setdefault(
task.transfer_batch_id,
@@ -1198,8 +1193,6 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
if not batch_id:
return
if not hasattr(self, "_scrape_batches"):
self._scrape_batches = {}
with job_lock:
batch = self._scrape_batches.setdefault(
batch_id,
@@ -1226,8 +1219,6 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
):
return
if not hasattr(self, "_scrape_batches"):
self._scrape_batches = {}
target_diritem = transferinfo.target_diritem
target_files = transferinfo.file_list_new or []
target_key = (target_diritem.storage, target_diritem.path)
@@ -1264,8 +1255,6 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
if not task or not task.transfer_batch_id:
return
if not hasattr(self, "_scrape_batches"):
self._scrape_batches = {}
with job_lock:
batch = self._scrape_batches.get(task.transfer_batch_id)
if not batch:
@@ -1279,8 +1268,6 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
"""
if not batch_id:
return
if not hasattr(self, "_scrape_batches"):
self._scrape_batches = {}
with job_lock:
batch = self._scrape_batches.get(batch_id)

View File

@@ -2,6 +2,7 @@ import unittest
from types import SimpleNamespace
from unittest.mock import patch, MagicMock
from app.core.config import settings
from app.chain.transfer import JobManager, TransferChain
from app.schemas import FileItem, TransferInfo, TransferTask
from app.schemas.types import EventType, MediaType
@@ -86,6 +87,20 @@ def make_task(episode: int, season: int = 1) -> TransferTask:
)
def make_transfer_chain() -> TransferChain:
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain._media_exts = settings.RMT_MEDIAEXT
chain._subtitle_exts = settings.RMT_SUBEXT
chain._audio_exts = settings.RMT_AUDIOEXT
chain._allowed_exts = (
chain._media_exts + chain._audio_exts + chain._subtitle_exts
)
chain._success_target_files = {}
chain._scrape_batches = {}
return chain
def migrate_to_media_job(jobview: JobManager, task: TransferTask):
task.mediainfo = FakeMedia()
jobview.migrate_task(task)
@@ -192,8 +207,7 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual(task2.fileitem, jobs[0].tasks[0].fileitem)
def test_exception_failure_does_not_mark_downloader_without_history(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain = make_transfer_chain()
completed = []
def fake_transfer_completed(hashs, downloader):
@@ -212,8 +226,7 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual([], chain.jobview.list_jobs())
def test_successful_history_skip_marks_downloader_hash_completed(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain = make_transfer_chain()
completed = []
def fake_transfer_completed(hashs, downloader):
@@ -255,8 +268,7 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual([("abc123", "qbittorrent")], completed)
def test_failed_history_skip_still_marks_downloader_hash_completed(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain = make_transfer_chain()
completed = []
def fake_transfer_completed(hashs, downloader):
@@ -298,10 +310,8 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual([("abc123", "qbittorrent")], completed)
def test_unrecognized_task_marks_downloader_hash_completed(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain = make_transfer_chain()
chain.post_message = lambda *_args, **_kwargs: None
chain._scrape_batches = {}
completed = []
def fake_transfer_completed(hashs, downloader):
@@ -336,10 +346,7 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual([], chain.jobview.list_jobs())
def test_scrape_event_is_aggregated_by_transfer_batch_across_seasons(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain._success_target_files = {}
chain._scrape_batches = {}
chain = make_transfer_chain()
chain.eventmanager = MagicMock()
chain.transfer_completed = lambda *args, **kwargs: None
@@ -430,10 +437,7 @@ class TransferJobManagerTest(unittest.TestCase):
self.assertEqual({}, chain._scrape_batches)
def test_scrape_event_keeps_immediate_behavior_without_transfer_batch(self):
chain = object.__new__(TransferChain)
chain.jobview = JobManager()
chain._success_target_files = {}
chain._scrape_batches = {}
chain = make_transfer_chain()
chain.eventmanager = MagicMock()
chain.transfer_completed = lambda *args, **kwargs: None