diff --git a/plugins.v2/clashruleprovider/__init__.py b/plugins.v2/clashruleprovider/__init__.py index 9c136fb..3b8e6de 100644 --- a/plugins.v2/clashruleprovider/__init__.py +++ b/plugins.v2/clashruleprovider/__init__.py @@ -1,35 +1,34 @@ -import json -import re -import urllib -from typing import Any, Optional, List, Dict, Tuple, Union -import time -from urllib.parse import urlparse -import yaml -import hashlib -from datetime import datetime, timedelta -import pytz -import copy -import math - -from apscheduler.schedulers.background import BackgroundScheduler -from apscheduler.triggers.cron import CronTrigger import asyncio -from fastapi import HTTPException, Request, status, Body, Response -import websockets -from sse_starlette.sse import EventSourceResponse +import copy +import hashlib +import json +import math +import re +import time +import urllib +from datetime import datetime, timedelta +from typing import Any, Optional, List, Dict, Tuple, Union +from urllib.parse import urlparse +import pytz +import websockets +import yaml from app import schemas from app.core.config import settings from app.core.event import eventmanager, Event -from app.schemas.types import EventType from app.log import logger -from app.schemas.types import NotificationType -from app.utils.ip import IpUtils -from app.utils.http import RequestUtils, AsyncRequestUtils from app.plugins import _PluginBase -from app.plugins.clashruleprovider.clashruleparser import ClashRuleParser, Converter from app.plugins.clashruleprovider.clashruleparser import Action, RuleType, ClashRule, MatchRule, LogicRule +from app.plugins.clashruleprovider.clashruleparser import ClashRuleParser, Converter from app.plugins.clashruleprovider.clashruleparser import ProxyGroup, RuleProvider +from app.schemas.types import EventType +from app.schemas.types import NotificationType +from app.utils.http import RequestUtils, AsyncRequestUtils +from app.utils.ip import IpUtils +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.cron import CronTrigger +from fastapi import HTTPException, Request, status, Body, Response +from sse_starlette.sse import EventSourceResponse class ClashRuleProvider(_PluginBase): @@ -171,12 +170,13 @@ class ClashRuleProvider(_PluginBase): # 清理不存在的 URL self._subscription_info = {url: self._subscription_info.get(url) for url in self._sub_links if self._subscription_info.get(url)} - self._clash_configs = {url: self._clash_configs[url] for url in self._sub_links if self._clash_configs.get(url)} + self._clash_configs = {url: self._clash_configs[url] for url in self._sub_links if + self._clash_configs.get(url)} self._scheduler = BackgroundScheduler(timezone=settings.TZ) self._scheduler.start() # 更新订阅 self._scheduler.add_job(self.refresh_subscriptions, "date", - run_date=datetime.now(tz=pytz.timezone(settings.TZ)) + timedelta(seconds=2)) + run_date=datetime.now(tz=pytz.timezone(settings.TZ)) + timedelta(seconds=2)) if self._hint_geo_dat: self._scheduler.add_job(self.__refresh_geo_dat, "date", run_date=datetime.now(tz=pytz.timezone(settings.TZ)) + timedelta(seconds=3)) @@ -189,7 +189,6 @@ class ClashRuleProvider(_PluginBase): else: self._acl4ssr_providers = {} - def get_state(self) -> bool: return self._enabled @@ -581,6 +580,7 @@ class ClashRuleProvider(_PluginBase): break finally: listener_task.cancel() # 停止与 Clash 的连接 + return EventSourceResponse(event_generator()) async def fetch_clash_data(self, endpoint: str) -> Dict: @@ -602,14 +602,14 @@ class ClashRuleProvider(_PluginBase): return schemas.Response(success=True, message="missing params") clash_version_url = f"{params.get('clash_dashboard_url')}/version" ret = await AsyncRequestUtils(accept_type="application/json", - headers={"authorization": f"Bearer {params.get('clash_dashboard_secret')}"} - ).get(clash_version_url) + headers={"authorization": f"Bearer {params.get('clash_dashboard_secret')}"} + ).get(clash_version_url) if ret is None: return schemas.Response(success=False, message="无法连接到Clash") for sub_link in (params.get('sub_links') or []): ret = await AsyncRequestUtils(accept_type="text/html", - proxies=settings.PROXY if self._proxy else None - ).get(sub_link) + proxies=settings.PROXY if self._proxy else None + ).get(sub_link) if ret is None: return schemas.Response(success=False, message=f"Unable to fetch {sub_link}") return schemas.Response(success=True, message="测试连接成功") @@ -639,7 +639,6 @@ class ClashRuleProvider(_PluginBase): "sub_url": f"{self._movie_pilot_url}/api/v1/plugin/ClashRuleProvider/config?" f"apikey={settings.API_TOKEN}"}} - def get_clash_config(self, request: Request): logger.info(f"{request.client.host} 正在获取配置") config = self.clash_config() @@ -713,7 +712,7 @@ class ClashRuleProvider(_PluginBase): res = self.delete_rule_by_priority(params.get('priority'), self._ruleset_rule_parser) if res: self.__add_notification_job( - [f"{self._ruleset_prefix}{res.action.value if isinstance(res.action, Action) else res.action}",]) + [f"{self._ruleset_prefix}{res.action.value if isinstance(res.action, Action) else res.action}", ]) else: self.delete_rule_by_priority(params.get('priority'), self._clash_rule_parser) return schemas.Response(success=True, message='') @@ -780,7 +779,7 @@ class ClashRuleProvider(_PluginBase): if params.get('type') == 'ruleset': res = self.add_rule_by_priority(params.get('rule_data'), self._ruleset_rule_parser) if res: - self.__add_notification_job([f"{self._ruleset_prefix}{params.get('rule_data').get('action')}",]) + self.__add_notification_job([f"{self._ruleset_prefix}{params.get('rule_data').get('action')}", ]) else: res = self.add_rule_by_priority(params.get('rule_data'), self._clash_rule_parser) return schemas.Response(success=bool(res), message='') @@ -852,7 +851,7 @@ class ClashRuleProvider(_PluginBase): return schemas.Response(success=True, data={'proxy_groups': []}) first_config = self._clash_configs.get(self._sub_links[0], {}) if self._sub_links else {} proxy_groups = [] - sources = ('Manual', 'Template', urlparse(self._sub_links[0]).hostname if self._sub_links else '' ,'Region') + sources = ('Manual', 'Template', urlparse(self._sub_links[0]).hostname if self._sub_links else '', 'Region') groups = (self._proxy_groups, self._clash_template.get('proxy-groups', []), first_config.get('proxy-groups', []), self.proxy_groups_by_region()) for i, group in enumerate(groups): @@ -1216,7 +1215,7 @@ class ClashRuleProvider(_PluginBase): continue tree = res.json() yaml_files = [item["path"][:item["path"].rfind('.')] for item in tree["tree"] if - item["type"] == "blob" and item['path'].endswith((".yaml", ".yml"))] + item["type"] == "blob" and item['path'].endswith((".yaml", ".yml"))] self._geo_rules[path["name"]] = yaml_files def refresh_subscriptions(self) -> Dict[str, bool]: @@ -1400,7 +1399,7 @@ class ClashRuleProvider(_PluginBase): """ # 使用模板或第一个订阅 first_config = self._clash_configs.get(self._sub_links[0], {}) if self._sub_links else {} - proxies =[] + proxies = [] if not self._clash_template: clash_config = copy.deepcopy(first_config) clash_config['proxy-groups'] = [] @@ -1418,7 +1417,7 @@ class ClashRuleProvider(_PluginBase): clash_config['rule-providers'] = clash_config.get('rule-providers') or {} clash_config['rule-providers'].update(first_config.get('rule-providers', {})) - for proxy in self.all_proxies() : + for proxy in self.all_proxies(): if any(p.get('name') == proxy.get('name', '') for p in proxies): logger.warn(f"Proxy named {proxy.get('name')!r} already exists. Skipping...") continue @@ -1457,11 +1456,11 @@ class ClashRuleProvider(_PluginBase): sub_url = (f"{self._movie_pilot_url}/api/v1/plugin/ClashRuleProvider/ruleset?" f"name={path_name}&apikey={settings.API_TOKEN}") self._rule_provider[rule_provider_name] = {"behavior": "classical", - "format": "yaml", - "interval": 3600, - "path": f"./CRP/{path_name}.yaml", - "type": "http", - "url": sub_url} + "format": "yaml", + "interval": 3600, + "path": f"./CRP/{path_name}.yaml", + "type": "http", + "url": sub_url} clash_config['rule-providers'].update(self._rule_provider) # 添加规则 for rule in self._clash_rule_parser.rules: @@ -1523,7 +1522,7 @@ class ClashRuleProvider(_PluginBase): return v6 @eventmanager.register(EventType.PluginAction) - def update_cloudflare_ips_handler(self, event:Event = None): + def update_cloudflare_ips_handler(self, event: Event = None): event_data = event.event_data if not event_data or event_data.get("action") != "update_cloudflare_ips": return @@ -1532,4 +1531,4 @@ class ClashRuleProvider(_PluginBase): ips = [ips] if isinstance(ips, list): logger.info(f"更新 Cloudflare 优选 IP ...") - self.update_best_cf_ip(ips) \ No newline at end of file + self.update_best_cf_ip(ips) diff --git a/plugins.v2/clashruleprovider/clashruleparser.py b/plugins.v2/clashruleprovider/clashruleparser.py index a947fa6..c6a5074 100644 --- a/plugins.v2/clashruleprovider/clashruleparser.py +++ b/plugins.v2/clashruleprovider/clashruleparser.py @@ -1,11 +1,11 @@ -import re -from typing import List, Dict, Any, Optional, Union, Callable, Literal -from dataclasses import dataclass -from enum import Enum -from urllib.parse import urlparse, parse_qs, unquote, parse_qsl, urlencode, urlunparse -import json import base64 import binascii +import json +import re +from dataclasses import dataclass +from enum import Enum +from typing import List, Dict, Any, Optional, Union, Callable, Literal +from urllib.parse import urlparse, parse_qs, unquote, parse_qsl from pydantic import BaseModel, Field, validator, HttpUrl @@ -56,9 +56,11 @@ class RuleProvider(BaseModel): raise ValueError("mrs format only supports 'domain' or 'ipcidr' behavior") return v + class RuleProviders(BaseModel): __root__: dict[str, RuleProvider] + class ProxyGroupBase(BaseModel): """ 包含所有代理组类型共有的通用字段。 @@ -97,7 +99,6 @@ class ProxyGroupBase(BaseModel): hidden: Optional[bool] = Field(False, description="Hides the proxy group in the API.") icon: Optional[str] = Field(None, description="Icon string for the proxy group, for UI use.") - @validator('expected_status', allow_reuse=True) def validate_expected_status(cls, v: Optional[str]) -> Optional[str]: if v is None or v == '*': @@ -109,25 +110,31 @@ class ProxyGroupBase(BaseModel): for part in parts: if '-' in part: start, end = part.split('-') - if not (start.isdigit() and end.isdigit() and 100 <= int(start) < 600 and 100 <= int(end) < 600 and int(start) <= int(end)): + if not (start.isdigit() and end.isdigit() and 100 <= int(start) < 600 and 100 <= int(end) < 600 and int( + start) <= int(end)): raise ValueError(f"Invalid status code range: {part}") elif not (part.isdigit() and 100 <= int(part) < 600): raise ValueError(f"Invalid status code: {part}") return v + class SelectGroup(ProxyGroupBase): type: Literal['select'] + class RelayGroup(ProxyGroupBase): type: Literal['relay'] + class FallbackGroup(ProxyGroupBase): type: Literal['fallback'] + class UrlTestGroup(ProxyGroupBase): type: Literal['url-test'] tolerance: Optional[int] = Field(None, description="proxies switch tolerance, measured in milliseconds (ms).") + class LoadBalanceGroup(ProxyGroupBase): type: Literal['load-balance'] strategy: Optional[Literal['round-robin', 'consistent-hashing', 'sticky-sessions']] = Field( @@ -135,12 +142,15 @@ class LoadBalanceGroup(ProxyGroupBase): description="Load balancing strategy." ) + # --- Discriminated Union --- ProxyGroupUnion = Union[SelectGroup, RelayGroup, FallbackGroup, UrlTestGroup, LoadBalanceGroup] + class ProxyGroup(BaseModel): __root__: ProxyGroupUnion + class AdditionalParam(Enum): NO_RESOLVE = 'no-resolve' SRC = 'src' @@ -860,7 +870,7 @@ class Converter: headers = {"User-Agent": Converter.user_agent} if host: headers["Host"] = host - ws_opts:Dict[str, Any] = { "path": path, "headers": headers } + ws_opts: Dict[str, Any] = {"path": path, "headers": headers} try: parsed_path = urlparse(path) q = dict(parse_qsl(parsed_path.query)) @@ -1226,4 +1236,4 @@ class Converter: if not skip_exception: raise ValueError("convert v2ray subscribe error: format invalid") - return proxies \ No newline at end of file + return proxies