Files
MoviePilot-Plugins/plugins.v2/lexiannot/query_gemini.py
2025-06-09 23:59:43 +08:00

220 lines
6.7 KiB
Python

import sys
import json
import time
from typing import List, Dict, Any, Type, Union
from pydantic import BaseModel, ValidationError
class Context(BaseModel):
original_text: str
class Vocabulary(BaseModel):
lemma: str
Chinese: str
class VocabularyTranslationTask(BaseModel):
index: int
vocabulary: List[Vocabulary]
context: Context
class DialogueTranslationTask(BaseModel):
index: int
original_text: str
Chinese: str
class GeminiResponse(BaseModel):
tasks: List[Union[VocabularyTranslationTask, DialogueTranslationTask]]
total_token_count: int
success: bool
message: str = ""
def validate_input_data(request_data: Dict[str, Any]) -> None:
"""Validate the input data structure"""
if not isinstance(request_data, dict):
raise ValueError("Input data must be a dictionary")
if "tasks" not in request_data:
raise ValueError("Missing 'tasks' in input data")
if "params" not in request_data:
raise ValueError("Missing 'params' in input data")
params = request_data["params"]
required_params = ["api_key", "system_instruction", "schema"]
for param in required_params:
if param not in params:
raise ValueError(f"Missing required parameter: {param}")
def get_task_schema(schema_name: str) -> Type[BaseModel]:
"""Get the appropriate schema class based on the schema name"""
schema_map = {
'DialogueTranslationTask': DialogueTranslationTask,
'VocabularyTranslationTask': VocabularyTranslationTask
}
if schema_name not in schema_map:
raise ValueError(f"Unknown schema name: {schema_name}")
return schema_map[schema_name]
def query_gemini(
api_key: str,
translation_tasks: List[Dict[str, Any]],
task_schema: Type[Union[VocabularyTranslationTask, DialogueTranslationTask]],
system_instruction: str,
gemini_model: str = "gemini-2.0-flash",
temperature: float = 0.3,
max_retries: int = 3,
retry_delay: int = 10
) -> GeminiResponse:
"""
Query the Gemini API for translation tasks with retry logic.
Args:
api_key: Gemini API key
translation_tasks: List of translation tasks
task_schema: Pydantic model for the task type
system_instruction: System instruction for the model
gemini_model: Model name to use
temperature: Generation temperature
max_retries: Number of retry attempts
retry_delay: Delay between retries in seconds
Returns:
GeminiResponse containing the results
"""
from google import genai
from google.genai import types
from google.genai.types import SchemaUnion
client = genai.Client(api_key=api_key)
messages = []
translation_res = []
total_token_count = 0
# Validate input tasks before sending to API
try:
translation_res = [task_schema(**task) for task in translation_tasks]
except ValidationError as e:
return GeminiResponse(
tasks=[],
total_token_count=0,
success=False,
message=f"Input validation failed: {str(e)}"
)
for attempt in range(1, max_retries + 1):
try:
response = client.models.generate_content(
model=gemini_model,
contents=json.dumps(translation_tasks, ensure_ascii=False),
config=types.GenerateContentConfig(
system_instruction=system_instruction,
response_mime_type="application/json",
response_schema=list[task_schema],
temperature=temperature
),
)
if not response.parsed:
raise ValueError("Empty response from Gemini API")
translation_res = response.parsed
total_token_count = response.usage_metadata.total_token_count
return GeminiResponse(
tasks=translation_res,
total_token_count=total_token_count,
success=True
)
except Exception as e:
messages.append(f"Attempt {attempt} failed: {str(e)}")
if attempt < max_retries:
time.sleep(retry_delay)
return GeminiResponse(
tasks=[],
total_token_count=0,
success=False,
message="All retry attempts failed. " + "\n".join(messages)
)
def main():
try:
# Read and parse input
'''{
"tasks": [{
"index": 0,
"original_text": "That was eight years ago.",
"Chinese": ""
}, {
"index": 1,
"original_text": "Much has changed.",
"Chinese": ""
}],
"params": {
"api_key": "",
"system_instruction": "You are an expert translator. You will be given a list of dialogue translation tasks in JSON format. For each entry, provide the most appropriate translation in Simplified Chinese based on the context. \\nOnly complete the `Chinese` field. Do not include pinyin, explanations, or any additional information.",
"schema": "DialogueTranslationTask"
}
}'''
input_text = sys.stdin.read()
if not input_text:
raise ValueError("No input provided")
request_data = json.loads(input_text)
validate_input_data(request_data)
# Extract parameters
tasks = request_data["tasks"]
params = request_data["params"]
# Get schema and make API call
schema = get_task_schema(params["schema"])
response = query_gemini(
api_key=params["api_key"],
translation_tasks=tasks,
task_schema=schema,
system_instruction=params["system_instruction"],
gemini_model=params.get("model", "gemini-2.0-flash"),
temperature=float(params.get("temperature", 0.3)),
max_retries=int(params.get("max_retries", 3))
)
# Prepare output
if response.success:
result = {
"success": True,
"data": {
"tasks": [task.model_dump() for task in response.tasks],
"total_token_count": response.total_token_count
}
}
else:
result = {
"success": False,
"message": response.message
}
print(json.dumps(result, ensure_ascii=False))
except json.JSONDecodeError as e:
error = {
"success": False,
"message": f"Invalid JSON input: {str(e)}"
}
print(json.dumps(error))
except Exception as e:
error = {
"success": False,
"message": f"Unexpected error: {str(e)}"
}
print(json.dumps(error))
if __name__ == "__main__":
main()