feat: accelerate RSS parsing with Rust

This commit is contained in:
jxxghp
2026-05-22 21:31:18 +08:00
parent 052e1ca8e4
commit 4de4044a3e
15 changed files with 467 additions and 102 deletions

View File

@@ -56,7 +56,7 @@ MCP工具API文档详见 [docs/mcp-api.md](docs/mcp-api.md)
开发环境准备与本地源码运行说明:[`docs/development-setup.md`](docs/development-setup.md)
本地开发启用 Rust 加速扩展,需先安装 Rust toolchain 并确保 `cargo` 可用:
本地开发启用 Rust 加速扩展,需先安装 Rust toolchain 并确保 `cargo` 可用;未安装时项目会自动使用 Python 实现
```shell
cargo --version
@@ -67,6 +67,12 @@ python -c "from app.utils import rust_accel; print(rust_accel.is_available())"
如果输出 `True`,说明当前开发环境已经加载 `moviepilot_rust`。重新修改 Rust 代码后再次执行 `python -m maturin develop --release --manifest-path rust/moviepilot_rust/Cargo.toml` 即可更新本地扩展。
需要本地评估 Rust 加速效果时,可运行:
```shell
python scripts/benchmark_rust_accel.py --loops 20 --repeat 5
```
插件开发说明:<https://wiki.movie-pilot.org/zh/plugindev>
## 相关项目

View File

@@ -9,6 +9,7 @@ from lxml import etree
from app.core.config import settings
from app.helper.browser import PlaywrightHelper
from app.log import logger
from app.utils import rust_accel
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
@@ -227,6 +228,32 @@ class RssHelper:
},
}
@staticmethod
def __format_rust_items(items: List[dict]) -> List[dict]:
"""
将 Rust RSS 解析结果转换为原 Python XPath 解析返回结构。
"""
ret_array = []
for item in items:
pubdate = ""
pubdate_raw = item.get("pubdate_raw")
if pubdate_raw:
pubdate = StringUtils.get_time(pubdate_raw)
if pubdate is not None:
pubdate = pubdate.astimezone(tz=None)
tmp_dict = {
'title': item.get("title") or "",
'enclosure': item.get("enclosure") or "",
'size': item.get("size") or 0,
'description': item.get("description") or "",
'link': item.get("link") or "",
'pubdate': pubdate
}
if item.get("nickname"):
tmp_dict['nickname'] = item.get("nickname")
ret_array.append(tmp_dict)
return ret_array
def parse(self, url, proxy: bool = False,
timeout: Optional[int] = 15, headers: dict = None, ua: str = None) -> Union[List[dict], None, bool]:
"""
@@ -298,6 +325,12 @@ class RssHelper:
logger.error("RSS内容不是有效的XML格式")
return False
rust_items = rust_accel.parse_rss_items(ret_xml, self.MAX_RSS_ITEMS)
if rust_items is not None:
if len(rust_items) >= self.MAX_RSS_ITEMS:
logger.warning(f"RSS条目过多仅处理前{self.MAX_RSS_ITEMS}")
return self.__format_rust_items(rust_items)
# 使用lxml.etree解析XML
parser = None
try:

View File

@@ -11,7 +11,6 @@ from requests import Session
from app.core.config import settings
from app.helper.cloudflare import under_challenge
from app.log import logger
from app.utils import rust_accel
from app.utils.http import RequestUtils
from app.utils.site import SiteUtils
from app.utils.string import StringUtils
@@ -159,11 +158,8 @@ class SiteParserBase(metaclass=ABCMeta):
@staticmethod
def num_filesize(text) -> int:
"""
将站点页面中的文件大小文本转换为字节,优先使用 Rust 快路径
将站点页面中的文件大小文本转换为字节。
"""
rust_value = rust_accel.parse_filesize(text)
if rust_value is not None:
return rust_value
return StringUtils.num_filesize(text)
def parse(self):

View File

@@ -672,9 +672,6 @@ class SiteSpider:
"""
if not text or not filters or not isinstance(filters, list):
return text
rust_text = rust_accel.apply_indexer_text_filters(text, filters)
if rust_text is not None:
return rust_text
if not isinstance(text, str):
text = str(text)
for filter_item in filters:

View File

@@ -117,34 +117,6 @@ def filter_torrents(
return None
def apply_indexer_text_filters(text: Any, filters: Optional[List[dict]]) -> Optional[str]:
"""
使用 Rust 执行 indexer 文本过滤器,不可用或遇到不支持过滤器时返回 None。
"""
if not _moviepilot_rust or not filters or not isinstance(filters, list):
return None
try:
return _moviepilot_rust.apply_indexer_text_filters_fast(None if text is None else str(text), filters)
except BaseException as err:
_raise_non_rust_panic(err)
logger.debug(f"Rust 站点文本过滤失败,回退 Python{err}")
return None
def parse_filesize(text: Any) -> Optional[int]:
"""
使用 Rust 将文件大小文本转换为字节,不可用时返回 None。
"""
if not _moviepilot_rust:
return None
try:
return int(_moviepilot_rust.parse_filesize_fast(text))
except BaseException as err:
_raise_non_rust_panic(err)
logger.debug(f"Rust 文件大小解析失败,回退 Python{err}")
return None
def build_indexer_search_url(config: dict) -> Optional[str]:
"""
使用 Rust 根据普通 indexer 配置生成搜索 URL不可用时返回 None。
@@ -187,6 +159,20 @@ def parse_indexer_torrents(
return None
def parse_rss_items(xml_text: str, max_items: int) -> Optional[List[dict]]:
"""
使用 Rust 批量解析 RSS/Atom 条目,不可用或解析失败时返回 None。
"""
if not _moviepilot_rust:
return None
try:
return _moviepilot_rust.parse_rss_items_fast(xml_text or "", int(max_items or 0))
except BaseException as err:
_raise_non_rust_panic(err)
logger.debug(f"Rust RSS 条目解析失败,回退 Python{err}")
return None
def _coerce_media_type(value: Optional[str]) -> Optional[MediaType]:
"""
将 Rust 返回的媒体类型字符串转换为系统 MediaType。

View File

@@ -360,6 +360,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"pyo3",
"quick-xml",
"regex",
"scraper",
"url",
@@ -551,6 +552,15 @@ dependencies = [
"syn",
]
[[package]]
name = "quick-xml"
version = "0.38.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c"
dependencies = [
"memchr",
]
[[package]]
name = "quote"
version = "1.0.45"

View File

@@ -11,6 +11,7 @@ crate-type = ["cdylib"]
once_cell = "1.20"
percent-encoding = "2.3"
pyo3 = { version = "0.23", features = ["abi3-py311", "extension-module"] }
quick-xml = "0.38"
regex = "1.11"
scraper = "0.24"
url = "2.5"

View File

@@ -3,7 +3,7 @@ use once_cell::sync::Lazy;
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyList};
use pyo3::types::{PyDict, PyList};
use regex::{Regex, RegexBuilder};
use scraper::{ElementRef, Html, Selector};
use url::form_urlencoded;
@@ -41,17 +41,6 @@ enum RowParseResult {
Item(PyObject),
}
#[pyfunction]
pub(crate) fn apply_indexer_text_filters_fast(
text: &Bound<'_, PyAny>,
filters: &Bound<'_, PyAny>,
) -> PyResult<Option<String>> {
if text.is_none() {
return Ok(None);
}
apply_text_filters(text.str()?.to_str()?.to_string(), filters)
}
/// 批量解析普通配置 indexer 页面,遇到不支持的选择器配置时返回 None 交给 Python 回退。
#[pyfunction]
#[pyo3(signature = (html_text, domain, list_config, fields, category=None, result_num=100))]
@@ -171,16 +160,7 @@ fn apply_text_filters(mut current: String, filters: &Bound<'_, PyAny>) -> PyResu
Ok(Some(current.trim().to_string()))
}
/// 将站点页面中的文件大小文本转换为字节数。
#[pyfunction]
pub(crate) fn parse_filesize_fast(text: &Bound<'_, PyAny>) -> PyResult<i64> {
if text.is_none() {
return Ok(0);
}
Ok(parse_filesize_text(text.str()?.to_str()?))
}
/// 将文件大小文本转换为字节数,供 Python 导出函数和 Rust HTML 解析共用。
/// 将文件大小文本转换为字节数,供 Rust HTML 解析内部共用
fn parse_filesize_text(text: &str) -> i64 {
let raw = text.trim().to_string();
if raw.is_empty() {

View File

@@ -1,6 +1,7 @@
mod filter;
mod indexer;
mod meta;
mod rss;
mod utils;
use pyo3::prelude::*;
@@ -20,12 +21,8 @@ fn moviepilot_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(meta::parse_video_title_fast, m)?)?;
m.add_function(wrap_pyfunction!(filter::parse_filter_rule_fast, m)?)?;
m.add_function(wrap_pyfunction!(filter::filter_torrents_fast, m)?)?;
m.add_function(wrap_pyfunction!(
indexer::apply_indexer_text_filters_fast,
m
)?)?;
m.add_function(wrap_pyfunction!(indexer::parse_filesize_fast, m)?)?;
m.add_function(wrap_pyfunction!(indexer::build_indexer_search_url_fast, m)?)?;
m.add_function(wrap_pyfunction!(indexer::parse_indexer_torrents_fast, m)?)?;
m.add_function(wrap_pyfunction!(rss::parse_rss_items_fast, m)?)?;
Ok(())
}

View File

@@ -0,0 +1,204 @@
use pyo3::prelude::*;
use pyo3::types::PyDict;
use quick_xml::events::{BytesStart, Event};
use quick_xml::name::QName;
use quick_xml::Reader;
#[derive(Default)]
struct RssItem {
title: String,
description: String,
link: String,
enclosure: String,
size: i64,
pubdate: String,
nickname: String,
}
impl RssItem {
fn has_output(&self) -> bool {
!self.title.is_empty() && (!self.enclosure.is_empty() || !self.link.is_empty())
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum TextField {
Title,
Description,
Link,
Pubdate,
Nickname,
}
/// 批量解析 RSS/Atom 条目,返回 Python 侧后续处理需要的核心字段。
#[pyfunction]
pub(crate) fn parse_rss_items_fast(
py: Python<'_>,
xml_text: &str,
max_items: usize,
) -> PyResult<Option<Vec<PyObject>>> {
let mut reader = Reader::from_str(xml_text);
reader.config_mut().trim_text(true);
let mut items = Vec::new();
let mut current_item: Option<RssItem> = None;
let mut current_field: Option<TextField> = None;
let mut item_depth = 0usize;
let mut parse_failed = false;
loop {
match reader.read_event() {
Ok(Event::Start(event)) => {
let local = local_name(event.name());
if current_item.is_none() && (local == "item" || local == "entry") {
current_item = Some(RssItem::default());
current_field = None;
item_depth = 1;
continue;
}
if let Some(item) = current_item.as_mut() {
item_depth += 1;
match local.as_str() {
"title" => current_field = Some(TextField::Title),
"description" | "summary" => current_field = Some(TextField::Description),
"pubDate" | "published" | "updated" => current_field = Some(TextField::Pubdate),
"creator" => current_field = Some(TextField::Nickname),
"link" => {
current_field = Some(TextField::Link);
if item.link.is_empty() {
if let Some(href) = attr_value(&event, QName(b"href")) {
item.link = href;
}
}
}
"enclosure" => {
if let Some(url) = attr_value(&event, QName(b"url")) {
item.enclosure = url;
}
if let Some(length) = attr_value(&event, QName(b"length")) {
item.size = length.parse::<i64>().unwrap_or(0);
}
}
_ => {}
}
}
}
Ok(Event::Empty(event)) => {
if let Some(item) = current_item.as_mut() {
match local_name(event.name()).as_str() {
"link" => {
if item.link.is_empty() {
if let Some(href) = attr_value(&event, QName(b"href")) {
item.link = href;
}
}
}
"enclosure" => {
if let Some(url) = attr_value(&event, QName(b"url")) {
item.enclosure = url;
}
if let Some(length) = attr_value(&event, QName(b"length")) {
item.size = length.parse::<i64>().unwrap_or(0);
}
}
_ => {}
}
}
}
Ok(Event::Text(event)) => {
if let (Some(item), Some(field)) = (current_item.as_mut(), current_field) {
if let Ok(value) = event.decode() {
append_field(item, field, value.as_ref());
}
}
}
Ok(Event::CData(event)) => {
if let (Some(item), Some(field)) = (current_item.as_mut(), current_field) {
if let Ok(value) = event.decode() {
append_field(item, field, value.as_ref());
}
}
}
Ok(Event::End(event)) => {
if current_item.is_some() {
let local = local_name(event.name());
if local == "item" || local == "entry" {
let mut item = current_item.take().unwrap_or_default();
if item.enclosure.is_empty() && !item.link.is_empty() {
item.enclosure = item.link.clone();
}
if item.has_output() {
items.push(item_to_py(py, &item)?.into_any().unbind());
if items.len() >= max_items {
break;
}
}
current_field = None;
item_depth = 0;
} else {
item_depth = item_depth.saturating_sub(1);
if item_depth <= 1 {
current_field = None;
}
}
}
}
Ok(Event::Eof) => break,
Err(_) => {
parse_failed = true;
break;
}
_ => {}
}
}
if parse_failed && items.is_empty() {
Ok(None)
} else {
Ok(Some(items))
}
}
/// 将内部 RSS 条目结构转换为 Python 字典。
fn item_to_py<'py>(py: Python<'py>, item: &RssItem) -> PyResult<Bound<'py, PyDict>> {
let dict = PyDict::new(py);
dict.set_item("title", item.title.trim())?;
dict.set_item("enclosure", item.enclosure.trim())?;
dict.set_item("size", item.size)?;
dict.set_item("description", item.description.trim())?;
dict.set_item("link", item.link.trim())?;
dict.set_item("pubdate_raw", item.pubdate.trim())?;
if !item.nickname.trim().is_empty() {
dict.set_item("nickname", item.nickname.trim())?;
}
Ok(dict)
}
/// 返回 XML 名称去掉命名空间前缀后的本地名称。
fn local_name(name: QName<'_>) -> String {
let raw = name.as_ref();
let local = raw.rsplit(|byte| *byte == b':').next().unwrap_or(raw);
std::str::from_utf8(local).unwrap_or("").to_string()
}
/// 读取 XML 节点属性并完成实体反转义。
fn attr_value(event: &BytesStart<'_>, name: QName<'_>) -> Option<String> {
event
.try_get_attribute(name)
.ok()
.flatten()
.and_then(|attr| attr.decode_and_unescape_value(event.decoder()).ok())
.map(|value| value.into_owned())
}
/// 追加当前文本节点到对应 RSS 字段。
fn append_field(item: &mut RssItem, field: TextField, value: &str) {
let target = match field {
TextField::Title => &mut item.title,
TextField::Description => &mut item.description,
TextField::Link => &mut item.link,
TextField::Pubdate => &mut item.pubdate,
TextField::Nickname => &mut item.nickname,
};
target.push_str(value);
}

View File

@@ -0,0 +1,114 @@
import argparse
import statistics
import sys
import time
from pathlib import Path
from typing import Callable
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from lxml import etree
from app.utils import rust_accel
from app.utils.string import StringUtils
def _time_call(func: Callable[[], None], loops: int) -> float:
"""
执行指定函数多轮并返回总耗时。
"""
start = time.perf_counter()
for _ in range(loops):
func()
return time.perf_counter() - start
def _median_time(func: Callable[[], None], loops: int, repeat: int) -> float:
"""
重复测量指定函数并返回中位耗时。
"""
return statistics.median(_time_call(func, loops) for _ in range(repeat))
def _rss_xml(item_count: int = 200) -> str:
"""
生成稳定的 RSS 测试数据,避免网络和站点波动影响结果。
"""
items = []
for index in range(item_count):
items.append(
f"""
<item>
<title>Example Torrent {index}</title>
<description><![CDATA[Example Desc {index}]]></description>
<link>https://example.org/details/{index}</link>
<enclosure url="https://example.org/download/{index}.torrent" length="{index + 1024}" />
<pubDate>Tue, 19 May 2026 08:30:00 GMT</pubDate>
<dc:creator>User {index}</dc:creator>
</item>
"""
)
return "<rss xmlns:dc=\"http://purl.org/dc/elements/1.1/\"><channel>" + "".join(items) + "</channel></rss>"
def _python_rss_parse(xml_text: str) -> None:
"""
执行与 RssHelper 原 XPath 路径等价的 RSS 条目字段提取。
"""
root = etree.fromstring(
xml_text.encode("utf-8"),
parser=etree.XMLParser(recover=True, strip_cdata=False, resolve_entities=False, no_network=True),
)
parsed = []
for item in root.xpath(".//item | .//entry")[:1000]:
title_nodes = item.xpath(".//title")
title = title_nodes[0].text if title_nodes and title_nodes[0].text else ""
desc_nodes = item.xpath(".//description | .//summary")
description = desc_nodes[0].text if desc_nodes and desc_nodes[0].text else ""
link_nodes = item.xpath(".//link")
link = link_nodes[0].text if link_nodes and link_nodes[0].text else ""
enclosure_nodes = item.xpath(".//enclosure")
enclosure = enclosure_nodes[0].get("url", "") if enclosure_nodes else link
pubdate_nodes = item.xpath('./pubDate | ./published | ./updated')
pubdate = StringUtils.get_time(pubdate_nodes[0].text) if pubdate_nodes and pubdate_nodes[0].text else ""
parsed.append((title, description, link, enclosure, pubdate))
root.clear()
def _print_result(name: str, python_seconds: float, rust_seconds: float) -> None:
"""
输出单项基准耗时和提升倍数。
"""
speedup = python_seconds / rust_seconds if rust_seconds else 0
print(f"{name}: Python {python_seconds:.4f}s, Rust {rust_seconds:.4f}s, speedup {speedup:.2f}x")
def run_benchmark(loops: int, repeat: int) -> None:
"""
运行核心 Rust 加速模块的本地微基准。
"""
print(f"moviepilot_rust available: {rust_accel.is_available()}")
xml_text = _rss_xml()
_print_result(
"rss item parse",
_median_time(lambda: _python_rss_parse(xml_text), loops, repeat),
_median_time(lambda: rust_accel.parse_rss_items(xml_text, 1000), loops, repeat),
)
def main() -> None:
"""
命令行入口。
"""
parser = argparse.ArgumentParser(description="Benchmark MoviePilot Rust acceleration paths.")
parser.add_argument("--loops", type=int, default=20)
parser.add_argument("--repeat", type=int, default=5)
args = parser.parse_args()
run_benchmark(max(args.loops, 1), max(args.repeat, 1))
if __name__ == "__main__":
main()

View File

@@ -544,10 +544,17 @@ ensure_prereqs() {
exit 1
fi
if ! ensure_base_tools || ! ensure_build_tools || ! ensure_python || ! ensure_uv || ! ensure_rust_toolchain; then
if ! ensure_base_tools || ! ensure_python || ! ensure_uv; then
python_install_hint
exit 1
fi
if ! rust_accel_should_skip; then
if ! ensure_build_tools || ! ensure_rust_toolchain; then
export MOVIEPILOT_SKIP_RUST_ACCEL=1
echo "==> Rust 加速扩展准备失败,已跳过;应用将继续使用 Python 实现"
fi
fi
}
prompt_text() {

View File

@@ -656,27 +656,28 @@ def _find_native_linker() -> Optional[str]:
return None
def ensure_rust_accel_ready() -> None:
def ensure_rust_accel_ready() -> bool:
"""
确认 Rust 加速扩展源码存在且本机具备 cargo 与链接器。
"""
if not RUST_ACCEL_MANIFEST.exists():
return
return False
if _rust_accel_should_skip():
print_step(f"已跳过 Rust 加速扩展构建:{RUST_ACCEL_SKIP_ENV}=1")
return
return False
if not _find_cargo():
raise RuntimeError(
"未找到 Rust cargo无法构建 MoviePilot Rust 加速扩展"
"请先安装 Rust toolchain 后重试,或临时设置 "
f"{RUST_ACCEL_SKIP_ENV}=1 跳过加速扩展。"
print_step(
"未找到 Rust cargo已跳过 Rust 加速扩展构建;"
"应用将继续使用 Python 实现。"
)
return False
if not _find_native_linker():
raise RuntimeError(
"未找到本机 C 编译器/链接器,无法构建 MoviePilot Rust 加速扩展"
"请先安装系统构建工具后重试,或临时设置 "
f"{RUST_ACCEL_SKIP_ENV}=1 跳过加速扩展。"
print_step(
"未找到本机 C 编译器/链接器,已跳过 Rust 加速扩展构建;"
"应用将继续使用 Python 实现。"
)
return False
return True
def install_rust_accel(venv_python: Path) -> None:
@@ -688,22 +689,26 @@ def install_rust_accel(venv_python: Path) -> None:
if _rust_accel_should_skip():
return
ensure_rust_accel_ready()
if not ensure_rust_accel_ready():
return
print_step("构建并安装 Rust 加速扩展")
env = os.environ.copy()
env["PATH"] = _cargo_env_path()
run(
[
str(venv_python),
"-m",
"maturin",
"develop",
"--release",
"--manifest-path",
str(RUST_ACCEL_MANIFEST),
],
env=env,
)
try:
run(
[
str(venv_python),
"-m",
"maturin",
"develop",
"--release",
"--manifest-path",
str(RUST_ACCEL_MANIFEST),
],
env=env,
)
except subprocess.CalledProcessError as exc:
print_step(f"Rust 加速扩展构建失败,已跳过;应用将继续使用 Python 实现:{exc}")
def ensure_supported_python(python_bin: str) -> None:
@@ -2732,7 +2737,6 @@ def install_deps(*, python_bin: str, venv_dir: Path, recreate: bool) -> Path:
创建或复用本地虚拟环境并安装后端依赖、Rust 扩展和浏览器运行时。
"""
ensure_supported_python(python_bin)
ensure_rust_accel_ready()
venv_dir = venv_dir.expanduser().resolve()
venv_python = get_venv_python(venv_dir)
venv_pip = get_venv_pip(venv_dir)

View File

@@ -99,13 +99,37 @@ def test_rust_indexer_search_url_keeps_existing_query_and_category():
assert "search_field=imdb0049406" in search_url
def test_rust_filesize_parser_matches_site_units():
def test_rust_rss_parser_extracts_common_rss_and_atom_fields():
"""
Rust 文件大小解析应覆盖站点解析器常见单位
Rust RSS 解析应同时覆盖 RSS item 和 Atom entry 的核心字段
"""
assert rust_accel.parse_filesize("1.5 GB") == 1610612736
assert rust_accel.parse_filesize("2 TiB") == 2199023255552
assert rust_accel.parse_filesize("42") == 42
xml_text = """
<rss><channel>
<item>
<title>Example Torrent</title>
<description><![CDATA[Desc]]></description>
<link>https://example.org/details/1</link>
<enclosure url="https://example.org/download/1.torrent" length="1024" />
<pubDate>Tue, 19 May 2026 08:30:00 GMT</pubDate>
<dc:creator>豆瓣用户</dc:creator>
</item>
<entry>
<title>Atom Torrent</title>
<summary>Atom Desc</summary>
<link href="https://example.org/atom/2" />
<updated>2026-05-19T09:30:00Z</updated>
</entry>
</channel></rss>
"""
items = rust_accel.parse_rss_items(xml_text, 100)
assert items[0]["title"] == "Example Torrent"
assert items[0]["enclosure"] == "https://example.org/download/1.torrent"
assert items[0]["size"] == 1024
assert items[0]["nickname"] == "豆瓣用户"
assert items[1]["title"] == "Atom Torrent"
assert items[1]["enclosure"] == "https://example.org/atom/2"
def test_rust_indexer_page_parser_handles_common_fields():

View File

@@ -16,13 +16,13 @@ def _load_subscribe_chain_class():
module = sys.modules[module_name]
return module, module.SubscribeChain
injected_modules = {}
original_modules = {}
def ensure_module(name: str, module: types.ModuleType):
if name in sys.modules:
return sys.modules[name]
"""临时替换模块依赖,并记录原模块以便加载完成后恢复。"""
if name not in original_modules:
original_modules[name] = sys.modules.get(name)
sys.modules[name] = module
injected_modules[name] = module
return module
chain_module = ensure_module("app.chain", types.ModuleType("app.chain"))
@@ -270,9 +270,15 @@ def _load_subscribe_chain_class():
sys.modules[module_name] = module
assert spec and spec.loader
spec.loader.exec_module(module)
module._injected_modules = injected_modules
for injected_name in injected_modules:
sys.modules.pop(injected_name, None)
module._injected_modules = {
name: sys.modules.get(name)
for name in original_modules
}
for injected_name, original_module in original_modules.items():
if original_module is None:
sys.modules.pop(injected_name, None)
else:
sys.modules[injected_name] = original_module
return module, module.SubscribeChain