mirror of
https://github.com/d0zingcat/MoviePilot-Plugins.git
synced 2026-05-16 15:09:22 +00:00
220 lines
6.7 KiB
Python
220 lines
6.7 KiB
Python
import sys
|
|
import json
|
|
import time
|
|
from typing import List, Dict, Any, Type, Union
|
|
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
|
|
class Context(BaseModel):
|
|
original_text: str
|
|
|
|
|
|
class Vocabulary(BaseModel):
|
|
lemma: str
|
|
Chinese: str
|
|
|
|
|
|
class VocabularyTranslationTask(BaseModel):
|
|
index: int
|
|
vocabulary: List[Vocabulary]
|
|
context: Context
|
|
|
|
|
|
class DialogueTranslationTask(BaseModel):
|
|
index: int
|
|
original_text: str
|
|
Chinese: str
|
|
|
|
|
|
class GeminiResponse(BaseModel):
|
|
tasks: List[Union[VocabularyTranslationTask, DialogueTranslationTask]]
|
|
total_token_count: int
|
|
success: bool
|
|
message: str = ""
|
|
|
|
|
|
def validate_input_data(request_data: Dict[str, Any]) -> None:
|
|
"""Validate the input data structure"""
|
|
if not isinstance(request_data, dict):
|
|
raise ValueError("Input data must be a dictionary")
|
|
if "tasks" not in request_data:
|
|
raise ValueError("Missing 'tasks' in input data")
|
|
if "params" not in request_data:
|
|
raise ValueError("Missing 'params' in input data")
|
|
|
|
params = request_data["params"]
|
|
required_params = ["api_key", "system_instruction", "schema"]
|
|
for param in required_params:
|
|
if param not in params:
|
|
raise ValueError(f"Missing required parameter: {param}")
|
|
|
|
|
|
def get_task_schema(schema_name: str) -> Type[BaseModel]:
|
|
"""Get the appropriate schema class based on the schema name"""
|
|
schema_map = {
|
|
'DialogueTranslationTask': DialogueTranslationTask,
|
|
'VocabularyTranslationTask': VocabularyTranslationTask
|
|
}
|
|
if schema_name not in schema_map:
|
|
raise ValueError(f"Unknown schema name: {schema_name}")
|
|
return schema_map[schema_name]
|
|
|
|
|
|
def query_gemini(
|
|
api_key: str,
|
|
translation_tasks: List[Dict[str, Any]],
|
|
task_schema: Type[Union[VocabularyTranslationTask, DialogueTranslationTask]],
|
|
system_instruction: str,
|
|
gemini_model: str = "gemini-2.0-flash",
|
|
temperature: float = 0.3,
|
|
max_retries: int = 3,
|
|
retry_delay: int = 10
|
|
) -> GeminiResponse:
|
|
"""
|
|
Query the Gemini API for translation tasks with retry logic.
|
|
|
|
Args:
|
|
api_key: Gemini API key
|
|
translation_tasks: List of translation tasks
|
|
task_schema: Pydantic model for the task type
|
|
system_instruction: System instruction for the model
|
|
gemini_model: Model name to use
|
|
temperature: Generation temperature
|
|
max_retries: Number of retry attempts
|
|
retry_delay: Delay between retries in seconds
|
|
|
|
Returns:
|
|
GeminiResponse containing the results
|
|
"""
|
|
from google import genai
|
|
from google.genai import types
|
|
from google.genai.types import SchemaUnion
|
|
client = genai.Client(api_key=api_key)
|
|
messages = []
|
|
translation_res = []
|
|
total_token_count = 0
|
|
|
|
# Validate input tasks before sending to API
|
|
try:
|
|
translation_res = [task_schema(**task) for task in translation_tasks]
|
|
except ValidationError as e:
|
|
return GeminiResponse(
|
|
tasks=[],
|
|
total_token_count=0,
|
|
success=False,
|
|
message=f"Input validation failed: {str(e)}"
|
|
)
|
|
|
|
for attempt in range(1, max_retries + 1):
|
|
try:
|
|
response = client.models.generate_content(
|
|
model=gemini_model,
|
|
contents=json.dumps(translation_tasks, ensure_ascii=False),
|
|
config=types.GenerateContentConfig(
|
|
system_instruction=system_instruction,
|
|
response_mime_type="application/json",
|
|
response_schema=list[task_schema],
|
|
temperature=temperature
|
|
),
|
|
)
|
|
|
|
if not response.parsed:
|
|
raise ValueError("Empty response from Gemini API")
|
|
|
|
translation_res = response.parsed
|
|
total_token_count = response.usage_metadata.total_token_count
|
|
return GeminiResponse(
|
|
tasks=translation_res,
|
|
total_token_count=total_token_count,
|
|
success=True
|
|
)
|
|
|
|
except Exception as e:
|
|
messages.append(f"Attempt {attempt} failed: {str(e)}")
|
|
if attempt < max_retries:
|
|
time.sleep(retry_delay)
|
|
|
|
return GeminiResponse(
|
|
tasks=[],
|
|
total_token_count=0,
|
|
success=False,
|
|
message="All retry attempts failed. " + "\n".join(messages)
|
|
)
|
|
|
|
|
|
def main():
|
|
try:
|
|
# Read and parse input
|
|
'''{
|
|
"tasks": [{
|
|
"index": 0,
|
|
"original_text": "That was eight years ago.",
|
|
"Chinese": ""
|
|
}, {
|
|
"index": 1,
|
|
"original_text": "Much has changed.",
|
|
"Chinese": ""
|
|
}],
|
|
"params": {
|
|
"api_key": "",
|
|
"system_instruction": "You are an expert translator. You will be given a list of dialogue translation tasks in JSON format. For each entry, provide the most appropriate translation in Simplified Chinese based on the context. \\nOnly complete the `Chinese` field. Do not include pinyin, explanations, or any additional information.",
|
|
"schema": "DialogueTranslationTask"
|
|
}
|
|
}'''
|
|
input_text = sys.stdin.read()
|
|
if not input_text:
|
|
raise ValueError("No input provided")
|
|
|
|
request_data = json.loads(input_text)
|
|
validate_input_data(request_data)
|
|
|
|
# Extract parameters
|
|
tasks = request_data["tasks"]
|
|
params = request_data["params"]
|
|
|
|
# Get schema and make API call
|
|
schema = get_task_schema(params["schema"])
|
|
response = query_gemini(
|
|
api_key=params["api_key"],
|
|
translation_tasks=tasks,
|
|
task_schema=schema,
|
|
system_instruction=params["system_instruction"],
|
|
gemini_model=params.get("model", "gemini-2.0-flash"),
|
|
temperature=float(params.get("temperature", 0.3)),
|
|
max_retries=int(params.get("max_retries", 3))
|
|
)
|
|
|
|
# Prepare output
|
|
if response.success:
|
|
result = {
|
|
"success": True,
|
|
"data": {
|
|
"tasks": [task.model_dump() for task in response.tasks],
|
|
"total_token_count": response.total_token_count
|
|
}
|
|
}
|
|
else:
|
|
result = {
|
|
"success": False,
|
|
"message": response.message
|
|
}
|
|
|
|
print(json.dumps(result, ensure_ascii=False))
|
|
|
|
except json.JSONDecodeError as e:
|
|
error = {
|
|
"success": False,
|
|
"message": f"Invalid JSON input: {str(e)}"
|
|
}
|
|
print(json.dumps(error))
|
|
except Exception as e:
|
|
error = {
|
|
"success": False,
|
|
"message": f"Unexpected error: {str(e)}"
|
|
}
|
|
print(json.dumps(error))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |