import time
from typing import Generic, List, TypeVar

from google import genai
from google.genai import types
from pydantic import BaseModel

# NOTE: the models below are documented with comments rather than class
# docstrings on purpose: pydantic copies a model's docstring into the
# "description" of its JSON schema, and `translate` sends that schema to the
# API as `response_schema`.


# Source text that accompanies a vocabulary task.
class Context(BaseModel):
    original_text: str


# One vocabulary entry: a lemma and its `Chinese` field
# (presumably the Chinese rendering of the lemma -- confirm with callers).
class Vocabulary(BaseModel):
    lemma: str
    Chinese: str


# Common base of every task; `T` below is bound to this class.
class TaskBase(BaseModel):
    id: str


# Task carrying a list of vocabulary entries plus the context they belong to.
class VocabularyTranslationTask(TaskBase):
    vocabulary: List[Vocabulary]
    context: Context
    index: int


# Task carrying one piece of original text and its `Chinese` field.
class DialogueTranslationTask(TaskBase):
    original_text: str
    Chinese: str
    index: int


T = TypeVar("T", bound=TaskBase)


# Request payload: a homogeneous list of tasks. The concrete (parametrized)
# class of an instance is also used as the response schema in `translate`.
class TranslationTasks(BaseModel, Generic[T]):
    tasks: List[T]


# Result wrapper returned by `translate`, for both success and failure.
class GeminiResponse(BaseModel, Generic[T]):
    tasks: List[T]
    total_token_count: int
    success: bool
    message: str = ""


def translate(
    api_key: str,
    translation_tasks: TranslationTasks[T],
    system_instruction: str,
    gemini_model: str = "gemini-2.0-flash",
    temperature: float = 0.3,
    max_retries: int = 3,
    retry_delay: int = 10,
) -> GeminiResponse[T]:
    """
    Query the Gemini API for translation tasks with retry logic.

    The tasks are serialized to JSON and sent as the request contents; the
    model is asked to answer with JSON matching the concrete type of
    ``translation_tasks``. Any exception raised during an attempt is recorded
    and the attempt is retried; this function itself does not raise for
    failed attempts.

    :param api_key: Gemini API key
    :param translation_tasks: Translation tasks
    :param system_instruction: System instruction
    :param gemini_model: Model name to use
    :param temperature: Generation temperature
    :param max_retries: Maximum number of attempts
    :param retry_delay: Base delay in seconds; the wait after a failed
        attempt is ``attempt * retry_delay`` (linear backoff)
    :return: GeminiResponse containing the results. On success ``success``
        is True and ``tasks`` holds the parsed tasks; after all attempts
        fail, ``success`` is False, ``tasks`` is empty and ``message`` lists
        the error of every attempt.
    """
    messages = []
    # Parametrized model class of the request; reused as the response schema.
    response_schema = type(translation_tasks)
    # Created lazily inside the try block so that a constructor failure is
    # still reported through GeminiResponse, but only built once.
    client = None

    for attempt in range(1, max_retries + 1):
        try:
            if client is None:
                client = genai.Client(api_key=api_key)

            response = client.models.generate_content(
                model=gemini_model,
                contents=translation_tasks.model_dump_json(),
                config=types.GenerateContentConfig(
                    system_instruction=system_instruction,
                    response_mime_type="application/json",
                    response_schema=response_schema,
                    temperature=temperature,
                ),
            )

            if not response.parsed:
                raise ValueError("Empty response from Gemini API")

            translation_res = response.parsed

            # usage_metadata may be absent; a missing token count must not
            # discard an otherwise successful translation and force a retry.
            usage = response.usage_metadata
            total_token_count = usage.total_token_count if usage else None

            return GeminiResponse(
                tasks=translation_res.tasks,
                total_token_count=total_token_count or 0,
                success=True,
            )

        except Exception as e:
            # Deliberately broad: every failure mode counts as a failed
            # attempt and is surfaced in the final message.
            messages.append(f"Attempt {attempt} failed: {str(e)}")
            if attempt < max_retries:
                time.sleep(attempt * retry_delay)

    return GeminiResponse(
        tasks=[],
        total_token_count=0,
        success=False,
        message="All retry attempts failed. " + "\n".join(messages),
    )