From 2d0e6d125d67fc75bab996695435a4f28d8ca649 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 4 Dec 2023 16:43:11 +0800 Subject: [PATCH] =?UTF-8?q?fix=20#21=20ChatGPT=E6=94=AF=E6=8C=81=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- package.json | 2 +- plugins/chatgpt/__init__.py | 34 ++++++++++++++++++++++++++++------ plugins/chatgpt/openai.py | 10 ++++++---- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/package.json b/package.json index b5371c1..f8fd269 100644 --- a/package.json +++ b/package.json @@ -146,7 +146,7 @@ "ChatGPT": { "name": "ChatGPT", "description": "消息交互支持与ChatGPT对话。", - "version": "1.1", + "version": "1.2", "icon": "Chatgpt_A.png", "author": "jxxghp", "level": 1 diff --git a/plugins/chatgpt/__init__.py b/plugins/chatgpt/__init__.py index 048a6a6..5bf95b7 100644 --- a/plugins/chatgpt/__init__.py +++ b/plugins/chatgpt/__init__.py @@ -16,7 +16,7 @@ class ChatGPT(_PluginBase): # 插件图标 plugin_icon = "Chatgpt_A.png" # 插件版本 - plugin_version = "1.1" + plugin_version = "1.2" # 插件作者 plugin_author = "jxxghp" # 作者主页 @@ -35,6 +35,7 @@ class ChatGPT(_PluginBase): _recognize = False _openai_url = None _openai_key = None + _model = None def init_plugin(self, config: dict = None): if config: @@ -43,8 +44,11 @@ class ChatGPT(_PluginBase): self._recognize = config.get("recognize") self._openai_url = config.get("openai_url") self._openai_key = config.get("openai_key") - self.openai = OpenAi(api_key=self._openai_key, api_url=self._openai_url, - proxy=settings.PROXY if self._proxy else None) + self._model = config.get("model") + if self._openai_url and self._openai_key: + self.openai = OpenAi(api_key=self._openai_key, api_url=self._openai_url, + proxy=settings.PROXY if self._proxy else None, + model=self._model) def get_state(self) -> bool: return self._enabled @@ -124,7 +128,7 @@ class ChatGPT(_PluginBase): 'component': 'VCol', 'props': { 'cols': 12, - 'md': 6 + 'md': 4 }, 'content': [ { @@ -141,7 +145,7 @@ class ChatGPT(_PluginBase): 'component': 'VCol', 'props': { 'cols': 12, - 'md': 6 + 'md': 4 }, 'content': [ { @@ -152,6 +156,23 @@ class ChatGPT(_PluginBase): } } ] + }, + { + 'component': 'VCol', + 'props': { + 'cols': 12, + 'md': 4 + }, + 'content': [ + { + 'component': 'VTextField', + 'props': { + 'model': 'model', + 'label': '自定义模型', + 'placeholder': 'gpt-3.5-turbo', + } + } + ] } ] }, @@ -162,7 +183,8 @@ class ChatGPT(_PluginBase): "proxy": False, "recognize": False, "openai_url": "https://api.openai.com", - "openai_key": "" + "openai_key": "", + "model": "gpt-3.5-turbo" } def get_page(self) -> List[dict]: diff --git a/plugins/chatgpt/openai.py b/plugins/chatgpt/openai.py index 3613926..937ecea 100644 --- a/plugins/chatgpt/openai.py +++ b/plugins/chatgpt/openai.py @@ -11,14 +11,17 @@ OpenAISessionCache = Cache(maxsize=100, ttl=3600, timer=time.time, default=None) class OpenAi: _api_key: str = None _api_url: str = None + _model: str = "gpt-3.5-turbo" - def __init__(self, api_key: str = None, api_url: str = None, proxy: dict = None): + def __init__(self, api_key: str = None, api_url: str = None, proxy: dict = None, model: str = None): self._api_key = api_key self._api_url = api_url openai.api_base = self._api_url + "/v1" openai.api_key = self._api_key if proxy and proxy.get("https"): openai.proxy = proxy.get("https") + if model: + self._model = model def get_state(self) -> bool: return True if self._api_key else False @@ -65,8 +68,7 @@ class OpenAi: OpenAISessionCache.set(session_id, seasion) return seasion - @staticmethod - def __get_model(message: Union[str, List[dict]], + def __get_model(self, message: Union[str, List[dict]], prompt: str = None, user: str = "MoviePilot", **kwargs): @@ -93,7 +95,7 @@ class OpenAi: } ] return openai.ChatCompletion.create( - model="gpt-3.5-turbo", + model=self._model, user=user, messages=message, **kwargs