created embeddings

This commit is contained in:
George Powell
2025-12-26 00:33:16 -05:00
parent d8cff2ff7a
commit 0daefcb080
8 changed files with 443 additions and 4 deletions


@@ -0,0 +1,98 @@
import { pipeline } from '@xenova/transformers';
import type { FeatureExtractionPipeline } from '@xenova/transformers';
import fs from 'fs/promises';
let extractor: FeatureExtractionPipeline | null = null;
const EMBEDDING_DIM = 384; // output width of the all-MiniLM models; must match the loaded model (gte-base would be 768)
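// In-memory index: one embedding per verse, kept parallel to the verses array below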
let verseEmbeddings: Float32Array[] = [];
let verses: Array<{ text: string; book: string; chapter: number; verse: number }> = [];
// Initialize once on server startup
export async function initializeEmbeddings(bibleVerses: Array<{ text: string; book: string; chapter: number; verse: number; }>) {
if (extractor) return; // Already initialized
console.log('Loading embedding model...');
// Current model; the commented lines below are alternatives tried during testing
extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L12-v2');
// extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
// extractor = await pipeline('feature-extraction', 'Xenova/gte-base');
verses = bibleVerses;
// Cache file is named per model so embeddings from a different model are never reused
const CACHE_PATH = './embeddings-cache-L12.json';
// const CACHE_PATH = './embeddings-cache-L6.json';
// const CACHE_PATH = './embeddings-cache-GTE-base.json';
try {
const cachedStr = await fs.readFile(CACHE_PATH, 'utf-8'); // throws into the catch below if the cache is missing
const cached = JSON.parse(cachedStr);
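// Cache shape: { embeddings: number[][], verses: { text, book, chapter, verse }[] }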
verseEmbeddings = cached.embeddings.map((arr: number[]) => Float32Array.from(arr));
verses = cached.verses;
console.log('Loaded embeddings from cache!');
return;
} catch {
console.log('No usable cache found, computing embeddings...');
}
console.log(`Encoding ${verses.length} verses in small batches to manage memory...`);
const BATCH_SIZE = 128;
const texts = verses.map((v) => v.text);
verseEmbeddings = [];
for (let start = 0; start < texts.length; start += BATCH_SIZE) {
const batchTexts = texts.slice(start, start + BATCH_SIZE);
console.log(`Processing batch ${Math.floor(start / BATCH_SIZE) + 1} (${batchTexts.length} verses)...`);
const output = await extractor!(batchTexts, { pooling: 'mean', normalize: true });
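// output.data is a flat Float32Array of length batchTexts.length * EMBEDDING_DIM (one mean-pooled, normalized row per text)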
const data = output.data as Float32Array;
for (let k = 0; k < batchTexts.length; k++) {
verseEmbeddings.push(data.slice(k * EMBEDDING_DIM, (k + 1) * EMBEDDING_DIM));
}
}
// Save to cache
const embeddingsData = {
embeddings: verseEmbeddings.map(e => Array.from(e)),
verses: verses
};
await fs.writeFile(CACHE_PATH, JSON.stringify(embeddingsData));
console.log('Embeddings computed and cached to disk!');
}
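// Plain dot product: the embeddings are L2-normalized at encode time (normalize: true), so this equals cosine similarity.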
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
let sum = 0;
for (let i = 0; i < a.length; i++) {
sum += a[i] * b[i];
}
return sum;
}
export async function findSimilarVerses(sentence: string, topK: number = 10) {
if (!extractor || verseEmbeddings.length === 0) {
throw new Error('Embeddings not initialized');
}
if (verseEmbeddings.length !== verses.length) {
throw new Error(`Embeddings/verses length mismatch: ${verseEmbeddings.length} != ${verses.length}`);
}
// Encode query sentence
const queryOutput = await extractor(sentence, { pooling: 'mean', normalize: true });
const queryEmbedding = queryOutput.data as Float32Array;
if (queryEmbedding.length !== EMBEDDING_DIM) {
throw new Error(`Query embedding dim mismatch: ${queryEmbedding.length} != ${EMBEDDING_DIM}`);
}
// Calculate similarities
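// Brute-force linear scan over every verse embedding; no ANN index needed at this scale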
const scores = verses.map((verse, idx) => ({
...verse,
score: cosineSimilarity(queryEmbedding, verseEmbeddings[idx])
}));
// Sort and return top K
return scores
.sort((a, b) => b.score - a.score)
.slice(0, topK);
}
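
A minimal usage sketch, assuming this file is imported as ./embeddings and the verses come from a JSON file (both are assumptions, not part of this commit):

import fs from 'fs/promises';
import { initializeEmbeddings, findSimilarVerses } from './embeddings'; // hypothetical path

// Hypothetical verses.json: [{ "text": "...", "book": "Genesis", "chapter": 1, "verse": 1 }, ...]
const bibleVerses = JSON.parse(await fs.readFile('./verses.json', 'utf-8'));
await initializeEmbeddings(bibleVerses); // loads the model, then either the cache or a full encode
const top = await findSimilarVerses('love your neighbor as yourself', 5);
console.log(top); // [{ text, book, chapter, verse, score }, ...] sorted by descending score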