Bladeren bron

Merge pull request #9 from burke/single-tokenize

Optimize chunking to tokenize once per document
Tobias Lütke 4 maanden geleden
bovenliggende
commit
fe0fd0859b
1 gewijzigde bestanden met toevoegingen van 23 en 54 verwijderingen
  1. 23 54
      src/store.ts

+ 23 - 54
src/store.ts

@@ -1057,49 +1057,28 @@ export async function chunkDocumentByTokens(
 ): Promise<{ text: string; pos: number; tokens: number }[]> {
   const llm = getDefaultLlamaCpp();
 
-  // For small documents, check if we need chunking at all
-  const totalTokens = await llm.countTokens(content);
+  // Tokenize once upfront
+  const allTokens = await llm.tokenize(content);
+  const totalTokens = allTokens.length;
+
   if (totalTokens <= maxTokens) {
     return [{ text: content, pos: 0, tokens: totalTokens }];
   }
 
   const chunks: { text: string; pos: number; tokens: number }[] = [];
-  let charPos = 0;
+  const step = maxTokens - overlapTokens;
+  const avgCharsPerToken = content.length / totalTokens;
+  let tokenPos = 0;
 
-  while (charPos < content.length) {
-    // Binary search to find the right chunk end position
-    // Start with an estimate based on average tokens per char
-    const avgCharsPerToken = content.length / totalTokens;
-    let estimatedEnd = Math.min(charPos + Math.floor(maxTokens * avgCharsPerToken * 1.1), content.length);
-
-    // Get token count for this slice
-    let slice = content.slice(charPos, estimatedEnd);
-    let sliceTokens = await llm.countTokens(slice);
-
-    // Adjust until we're close to maxTokens
-    while (sliceTokens > maxTokens && estimatedEnd > charPos + 100) {
-      // Reduce by ~10%
-      estimatedEnd = charPos + Math.floor((estimatedEnd - charPos) * 0.9);
-      slice = content.slice(charPos, estimatedEnd);
-      sliceTokens = await llm.countTokens(slice);
-    }
+  while (tokenPos < totalTokens) {
+    const chunkEnd = Math.min(tokenPos + maxTokens, totalTokens);
+    const chunkTokens = allTokens.slice(tokenPos, chunkEnd);
+    let chunkText = await llm.detokenize(chunkTokens);
 
-    // If we're under, try to expand (but not past content end)
-    while (sliceTokens < maxTokens * 0.9 && estimatedEnd < content.length) {
-      const newEnd = Math.min(estimatedEnd + Math.floor((estimatedEnd - charPos) * 0.1), content.length);
-      if (newEnd === estimatedEnd) break;
-      const newSlice = content.slice(charPos, newEnd);
-      const newTokens = await llm.countTokens(newSlice);
-      if (newTokens > maxTokens) break;
-      estimatedEnd = newEnd;
-      slice = newSlice;
-      sliceTokens = newTokens;
-    }
-
-    // Find a good break point in the last 30% of the chunk
-    if (estimatedEnd < content.length) {
-      const searchStart = charPos + Math.floor((estimatedEnd - charPos) * 0.7);
-      const searchSlice = content.slice(searchStart, estimatedEnd);
+    // Find a good break point if not at end of document
+    if (chunkEnd < totalTokens) {
+      const searchStart = Math.floor(chunkText.length * 0.7);
+      const searchSlice = chunkText.slice(searchStart);
 
       let breakOffset = -1;
       const paragraphBreak = searchSlice.lastIndexOf('\n\n');
@@ -1120,34 +1099,24 @@ export async function chunkDocumentByTokens(
           const lineBreak = searchSlice.lastIndexOf('\n');
           if (lineBreak >= 0) {
             breakOffset = lineBreak + 1;
-          } else {
-            const spaceBreak = searchSlice.lastIndexOf(' ');
-            if (spaceBreak >= 0) {
-              breakOffset = spaceBreak + 1;
-            }
           }
         }
       }
 
       if (breakOffset >= 0) {
-        estimatedEnd = searchStart + breakOffset;
-        slice = content.slice(charPos, estimatedEnd);
-        sliceTokens = await llm.countTokens(slice);
+        chunkText = chunkText.slice(0, searchStart + breakOffset);
       }
     }
 
-    chunks.push({ text: slice, pos: charPos, tokens: sliceTokens });
+    // Approximate character position based on token position
+    const charPos = Math.floor(tokenPos * avgCharsPerToken);
+    chunks.push({ text: chunkText, pos: charPos, tokens: chunkTokens.length });
 
-    // Move forward with overlap
-    if (estimatedEnd >= content.length) break;
+    // Move forward
+    if (chunkEnd >= totalTokens) break;
 
-    // Calculate overlap in characters based on token ratio
-    const overlapChars = Math.floor(overlapTokens * (slice.length / sliceTokens));
-    charPos = estimatedEnd - overlapChars;
-    const lastChunkPos = chunks.at(-1)!.pos;
-    if (charPos <= lastChunkPos) {
-      charPos = estimatedEnd;  // Prevent infinite loop
-    }
+    // Advance by step tokens (maxTokens - overlap)
+    tokenPos += step;
   }
 
   return chunks;