// Mirror of https://github.com/hicccc77/WeFlow.git
// Synced 2026-03-24 23:06:51 +00:00 · 307 lines · 8.6 KiB · TypeScript
import { parentPort, workerData } from 'worker_threads'
|
|
import { join } from 'path'
|
|
import { mkdirSync } from 'fs'
|
|
import * as lancedb from '@lancedb/lancedb'
|
|
import { pipeline, env } from '@xenova/transformers'
|
|
import { wcdbService } from './services/wcdbService'
|
|
import { mapRowToCloneMessage, CloneMessage, CloneRole } from './services/cloneMessageUtils'
|
|
|
|
/** Startup configuration handed to this worker through `workerData`. */
interface WorkerConfig {
  /** Exported as `WCDB_RESOURCES_PATH` and forwarded to `wcdbService.setPaths`. */
  resourcesPath?: string
  /** Base directory for the transformers cache and the per-session `clone_memory` stores. */
  userDataPath?: string
  /** Forwarded to `wcdbService.setLogEnabled`; only a literal `true` turns logging on. */
  logEnabled?: boolean
  /** Embedding model id; `Xenova/bge-small-zh-v1.5` is used when omitted. */
  embeddingModel?: string
}
|
|
|
|
/**
 * Messages accepted from the parent thread, discriminated by `type`.
 * `id` is echoed back in the matching response message.
 */
type WorkerRequest =
  | { id: string; type: 'index'; payload: IndexPayload }
  | { id: string; type: 'query'; payload: QueryPayload }
|
|
|
|
/** Parameters of an `index` request. */
interface IndexPayload {
  /** Session whose messages are read; also names the on-disk memory directory. */
  sessionId: string
  /** Passed to `wcdbService.open` together with `decryptKey` and `myWxid`. */
  dbPath: string
  decryptKey: string
  /** Also passed to `mapRowToCloneMessage` for every fetched row. */
  myWxid: string
  /** Passed to `wcdbService.openMessageCursorLite`; defaults to 200. */
  batchSize?: number
  /** Largest `createTime` gap allowed between consecutive messages of one chunk; defaults to 600. */
  chunkGapSeconds?: number
  /** Upper bound on the joined content length of a chunk; defaults to 400. */
  maxChunkChars?: number
  /** Upper bound on the number of messages merged into one chunk; defaults to 20. */
  maxChunkMessages?: number
  /** When true, an existing `messages` table is dropped before indexing; defaults to false. */
  reset?: boolean
}
|
|
|
|
/** Parameters of a `query` request. */
interface QueryPayload {
  /** Session whose memory table is searched. */
  sessionId: string
  /** Text that is embedded for the vector search and reused for the substring fallback. */
  keyword: string
  /** Maximum number of rows to return; defaults to 5. */
  topK?: number
  /** When set, only rows whose `role` column equals this value are requested. */
  roleFilter?: CloneRole
}
|
|
|
|
// ---- Worker bootstrap -------------------------------------------------------
// Runs once at module load: reads the worker configuration, exposes it to
// wcdbService and configures the transformers.js environment.

// Fall back to an empty config so a worker started without workerData does not
// crash on the first property access.
const config = (workerData ?? {}) as WorkerConfig

process.env.WEFLOW_WORKER = '1'
if (config.resourcesPath) {
  process.env.WCDB_RESOURCES_PATH = config.resourcesPath
}

wcdbService.setPaths(config.resourcesPath || '', config.userDataPath || '')
wcdbService.setLogEnabled(config.logEnabled === true)

env.allowRemoteModels = true
// NOTE(review): presumably meant to keep ONNX off the wasm backend inside the
// worker — confirm that `wasm.enabled` is honoured by the installed version.
if (env.backends?.onnx?.wasm) {
  env.backends.onnx.wasm.enabled = false
}

const embeddingModel = config.embeddingModel || 'Xenova/bge-small-zh-v1.5'

// Lazily created feature-extraction pipeline; see ensureEmbedder().
let embedder: any | null = null
|
|
|
|
/**
 * Returns the shared feature-extraction pipeline, creating it on first use.
 *
 * When `config.userDataPath` is set, the transformers cache directory is
 * pointed at `<userDataPath>/transformers` before the pipeline is created.
 *
 * @returns The cached pipeline instance for `embeddingModel`.
 */
async function ensureEmbedder() {
  if (!embedder) {
    if (config.userDataPath) {
      env.cacheDir = join(config.userDataPath, 'transformers')
    }
    embedder = await pipeline('feature-extraction', embeddingModel)
  }
  return embedder
}
|
|
|
|
/**
 * Resolves (and creates if needed) the directory that holds the vector store
 * of one session: `<userDataPath or cwd>/clone_memory/<sanitized sessionId>`.
 *
 * @param sessionId Session identifier; characters that are invalid in file
 *   names (`\ / : " * ? < > |`) are collapsed to `_`.
 * @returns The absolute or cwd-relative directory path, guaranteed to exist.
 */
function getMemoryDir(sessionId: string): string {
  const cleaned = sessionId.replace(/[\\/:"*?<>|]+/g, '_')
  // An empty or dots-only name ('', '.', '..') would resolve to the
  // clone_memory root or its parent instead of a per-session directory.
  const safeId = /^\.*$/.test(cleaned) ? '_' : cleaned
  const base = config.userDataPath || process.cwd()
  const dir = join(base, 'clone_memory', safeId)
  mkdirSync(dir, { recursive: true })
  return dir
}
|
|
|
|
/**
 * Connects to the session's LanceDB directory and reports whether a usable
 * `messages` table exists.
 *
 * @param sessionId Session whose memory directory is opened.
 * @param reset When true, an existing `messages` table is dropped first.
 * @returns The database connection and `hasTable`, which is true only when
 *   the table exists and was not just dropped.
 */
async function getTable(sessionId: string, reset?: boolean) {
  const db = await lancedb.connect(getMemoryDir(sessionId))
  const exists = (await db.tableNames()).includes('messages')

  if (exists && reset) {
    await db.dropTable('messages')
    return { db, hasTable: false }
  }
  return { db, hasTable: exists }
}
|
|
|
|
// Bracketed labels ([picture], [voice], [video], [sticker], [share]) that are
// excluded from indexing — presumably what upstream emits for non-text messages.
const SKIPPED_CONTENT_LABELS: ReadonlySet<string> = new Set([
  '[图片]',
  '[语音]',
  '[视频]',
  '[表情]',
  '[分享]'
])

/**
 * Tells whether a message body should be left out of the index.
 *
 * @param text Message content.
 * @returns True for empty content or content that is exactly one of the
 *   skipped labels; false otherwise (including labels embedded in longer text).
 */
function shouldSkipContent(text: string): boolean {
  return !text || SKIPPED_CONTENT_LABELS.has(text)
}
|
|
|
|
/**
 * Groups consecutive messages into chunks suitable for embedding.
 *
 * Messages rejected by `shouldSkipContent` are ignored. A message is appended
 * to the current chunk (contents joined with `\n`) unless one of these holds,
 * in which case the current chunk is closed and the message starts a new one:
 *   - its role differs from the chunk's role,
 *   - its `createTime` is more than `gapSeconds` after the chunk's last message,
 *   - the joined content would be longer than `maxChars`,
 *   - the chunk already holds `maxMessages` messages.
 *
 * A single message longer than `maxChars` is kept whole as its own chunk.
 *
 * @param messages Messages in the order they should be grouped.
 * @param gapSeconds Maximum `createTime` difference inside one chunk
 *   (assumes `createTime` is in seconds — TODO confirm).
 * @param maxChars Maximum length of a chunk's joined content.
 * @param maxMessages Maximum number of messages per chunk.
 * @returns The chunks in input order.
 */
function chunkMessages(
  messages: CloneMessage[],
  gapSeconds: number,
  maxChars: number,
  maxMessages: number
) {
  type Chunk = {
    role: CloneRole
    content: string
    tsStart: number
    tsEnd: number
    messageCount: number
  }

  // Opens a fresh chunk that contains only `msg`.
  const startChunk = (msg: CloneMessage): Chunk => ({
    role: msg.role,
    content: msg.content,
    tsStart: msg.createTime,
    tsEnd: msg.createTime,
    messageCount: 1
  })

  const chunks: Chunk[] = []
  let current: Chunk | null = null

  for (const msg of messages) {
    if (shouldSkipContent(msg.content)) continue

    if (current) {
      const merged = `${current.content}\n${msg.content}`
      const mustSplit =
        msg.role !== current.role ||
        msg.createTime - current.tsEnd > gapSeconds ||
        merged.length > maxChars ||
        current.messageCount >= maxMessages

      if (!mustSplit) {
        current.content = merged
        current.tsEnd = msg.createTime
        current.messageCount += 1
        continue
      }
      chunks.push(current)
    }

    current = startChunk(msg)
  }

  if (current) {
    chunks.push(current)
  }

  return chunks
}
|
|
|
|
/**
 * Embeds a batch of texts with the shared pipeline, using mean pooling and
 * normalization.
 *
 * @param texts Texts to embed.
 * @returns One vector per input text. The pipeline output is returned as-is
 *   when it is already an array, converted with `tolist()` when available,
 *   and replaced by an empty array otherwise.
 */
async function embedTexts(texts: string[]) {
  // Nothing to embed: skip loading/invoking the model altogether.
  if (texts.length === 0) return []

  const model = await ensureEmbedder()
  const output = await model(texts, { pooling: 'mean', normalize: true })

  if (Array.isArray(output)) return output
  return typeof output?.tolist === 'function' ? output.tolist() : []
}
|
|
|
|
/**
 * Collects best-effort diagnostics about a memory table: its row count and up
 * to three sample rows. Never throws.
 *
 * @param table Table handle, or null/undefined when no table was created.
 * @returns `{ rowCount, sample }`, or `{}` when there is no table or the row
 *   count cannot be read. `sample` is empty when only the sampling step fails.
 */
async function gatherDebugInfo(table: any) {
  if (!table) return {}

  let rowCount: any
  try {
    rowCount = await table.countRows()
  } catch {
    return {}
  }

  try {
    // The LanceDB table API exposes reads through query(); handles that offer
    // limit() directly are still supported for backward compatibility.
    const reader = typeof table.query === 'function' ? table.query() : table
    const sample = await reader.limit(3).toArray()
    return { rowCount, sample }
  } catch {
    // Sampling is optional: keep the row count we already have.
    return { rowCount, sample: [] }
  }
}
|
|
|
|
/**
 * Builds (or extends) the vector memory of one session.
 *
 * Opens the WCDB database, streams the session's messages batch by batch,
 * groups each batch into chunks, embeds the chunks and stores them in the
 * session's `messages` table. After every batch a `clone:indexProgress` event
 * is posted to the parent thread. Chunking is done per batch, so a chunk
 * never spans two fetched batches.
 *
 * @param requestId Id of the originating request, echoed in progress events.
 * @param payload Indexing parameters; see `IndexPayload` for defaults.
 * @returns `{ success: true, totalMessages, totalChunks, debug }`.
 * @throws Error when the database or cursor cannot be opened, a batch fetch
 *   fails, or the number of embeddings does not match the number of chunks.
 *   The cursor and the database are closed on every exit path once opened.
 */
async function handleIndex(requestId: string, payload: IndexPayload) {
  const {
    sessionId,
    dbPath,
    decryptKey,
    myWxid,
    batchSize = 200,
    chunkGapSeconds = 600,
    maxChunkChars = 400,
    maxChunkMessages = 20,
    reset = false
  } = payload

  const openOk = await wcdbService.open(dbPath, decryptKey, myWxid)
  if (!openOk) {
    throw new Error('WCDB open failed')
  }

  let table: lancedb.Table | null = null
  let totalMessages = 0
  let totalChunks = 0

  try {
    // NOTE(review): the meaning of the trailing `true, 0, 0` arguments is not
    // visible in this file — see wcdbService.openMessageCursorLite.
    const cursorResult = await wcdbService.openMessageCursorLite(sessionId, batchSize, true, 0, 0)
    if (!cursorResult.success || !cursorResult.cursor) {
      throw new Error(cursorResult.error || 'cursor open failed')
    }
    const cursor = cursorResult.cursor

    try {
      const { db, hasTable } = await getTable(sessionId, reset)
      if (hasTable) {
        table = await db.openTable('messages')
      }

      // When appending to an existing table, continue numbering after the rows
      // already stored so chunk ids stay unique.
      let chunkId = table ? await table.countRows() : 0
      let hasMore = true

      while (hasMore) {
        const batchResult = await wcdbService.fetchMessageBatch(cursor)
        if (!batchResult.success || !batchResult.rows) {
          throw new Error(batchResult.error || 'fetch batch failed')
        }

        totalMessages += batchResult.rows.length

        const messages: CloneMessage[] = []
        for (const row of batchResult.rows) {
          const msg = mapRowToCloneMessage(row, myWxid)
          if (msg) messages.push(msg)
        }

        const chunks = chunkMessages(messages, chunkGapSeconds, maxChunkChars, maxChunkMessages)
        if (chunks.length > 0) {
          const embeddings = await embedTexts(chunks.map((c) => c.content))
          if (embeddings.length !== chunks.length) {
            throw new Error('embedding size mismatch')
          }

          const rows = chunks.map((chunk, idx) => ({
            id: `${sessionId}-${chunkId + idx}`,
            sessionId,
            role: chunk.role,
            content: chunk.content,
            tsStart: chunk.tsStart,
            tsEnd: chunk.tsEnd,
            messageCount: chunk.messageCount,
            embedding: new Float32Array(embeddings[idx] || [])
          }))

          // The table is created lazily from the first non-empty batch.
          if (!table) {
            table = await db.createTable('messages', rows)
          } else {
            await table.add(rows)
          }

          chunkId += chunks.length
          totalChunks += chunks.length
        }

        hasMore = batchResult.hasMore === true
        parentPort?.postMessage({
          type: 'event',
          event: 'clone:indexProgress',
          data: { requestId, totalMessages, totalChunks, hasMore }
        })
      }
    } finally {
      await wcdbService.closeMessageCursor(cursor)
    }
  } finally {
    // Runs even when opening or closing the cursor fails.
    wcdbService.close()
  }

  const debug = await gatherDebugInfo(table)
  return { success: true, totalMessages, totalChunks, debug }
}
|
|
|
|
/**
 * Searches a session's memory table for rows related to `keyword`.
 *
 * The keyword is embedded and used for a vector search limited to `topK`
 * rows, optionally restricted to `roleFilter`. When the vector search returns
 * nothing, a best-effort fallback scans the table (with the same role
 * restriction) for rows whose content contains the trimmed, lower-cased
 * keyword.
 *
 * @param payload Query parameters; `topK` defaults to 5.
 * @returns `{ success: true, results, debug }` on success, or
 *   `{ success: false, error }` when the table is missing or the keyword could
 *   not be embedded. `debug.fallbackError` is present only when the fallback
 *   scan itself failed.
 */
async function handleQuery(payload: QueryPayload) {
  const { sessionId, keyword, topK = 5, roleFilter } = payload

  const { db, hasTable } = await getTable(sessionId, false)
  if (!hasTable) {
    return { success: false, error: 'memory table not found' }
  }
  const table = await db.openTable('messages')

  const embeddings = await embedTexts([keyword])
  const queryVector = embeddings.length ? embeddings[0] : undefined
  if (!queryVector) {
    return { success: false, error: 'embedding failed' }
  }

  // The filter is a SQL-like string: double any single quote so the value
  // cannot terminate the literal.
  const roleClause = roleFilter ? `role = '${String(roleFilter).replace(/'/g, "''")}'` : null

  const vectorSearch = table.search(new Float32Array(queryVector)).limit(topK)
  let rows = await (roleClause ? vectorSearch.where(roleClause) : vectorSearch).toArray()

  let usedFallback = false
  let fallbackError: string | undefined

  if (rows.length === 0) {
    usedFallback = true
    try {
      const needle = keyword.trim().toLowerCase()
      const scan = roleClause ? table.query().where(roleClause) : table.query()
      const all = await scan.toArray()
      rows = all
        .filter((row: any) => String(row.content || '').toLowerCase().includes(needle))
        .slice(0, topK)
    } catch (err) {
      // Best effort: results stay empty, but the reason is surfaced in debug.
      fallbackError = String(err)
    }
  }

  const debug = {
    rowsFound: rows.length,
    usedFallback,
    sample: rows.slice(0, 2),
    ...(fallbackError ? { fallbackError } : {})
  }

  return { success: true, results: rows, debug }
}
|
|
|
|
/**
 * Request dispatcher: routes each message from the parent thread to its
 * handler and always answers with a `response` message carrying the request
 * id. Handler failures are reported as `{ ok: false, error }` instead of
 * escaping the listener.
 */
parentPort?.on('message', async (request: WorkerRequest) => {
  // Captured up front so the id is still available after `request` has been
  // narrowed away in the default branch and inside the catch block.
  const { id } = request

  try {
    switch (request.type) {
      case 'index': {
        const data = await handleIndex(id, request.payload)
        parentPort?.postMessage({ type: 'response', id, ok: true, data })
        return
      }
      case 'query': {
        const data = await handleQuery(request.payload)
        parentPort?.postMessage({ type: 'response', id, ok: true, data })
        return
      }
      default:
        // Reached only for messages that do not match the declared union.
        parentPort?.postMessage({ type: 'response', id, ok: false, error: 'unknown request' })
    }
  } catch (err) {
    parentPort?.postMessage({ type: 'response', id, ok: false, error: String(err) })
  }
})
|