Long Context Embedding aka Late Chunking
Based on the "Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models" paper. The paper proposes a technique to add document-level context information to the individual chunks.
Code based on the Late Chunking blog post.
All credits to Jina!
This notebook explains how Long Context Embedding can be implemented with LangChain.
Notes:
- [Opinionated!] This notebook uses the Long Context Embedding term, which is more suitable than the Late Chunking term.
- The Text chunking term used in the paper is replaced with the text splitting term that is used in LangChain.
Set up
!pip install -U transformers
Load the model that we want to use for embedding. We choose jinaai/jina-embeddings-v2-base-en, but any other model that supports mean pooling can be used. Models with a large maximum context length are preferred for long context embedding.
from transformers import AutoModel, AutoTokenizer
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)
An illustration of the lost context problem. Here is a Wikipedia article about Berlin. One can see that phrases like "its" and "the city" reference "Berlin," which is mentioned only in the first sentence. This makes it harder for the embedding model to link these references to the correct entity, thereby producing a lower-quality vector representation.
input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."
Text splitting
Naive text splitting
We split the text on sentence boundaries, using the "." character as a separator and keeping the separator as part of the chunks. In real life, we would use a more robust and sophisticated text splitter (a sketch with LangChain's splitter follows the naive example below).
chunks = input_text.split(". ")
# take care of the separator at the end of the text:
naive_chunks = [(chunk + ".").replace("..", ".") for chunk in chunks]
print(chunks)
['Berlin is the capital and largest city of Germany, both by area and by population', "Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits", 'The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.']
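For reference, here is a sketch of a more robust splitter using LangChain's RecursiveCharacterTextSplitter (it assumes the langchain-text-splitters package is installed; the chunk_size and chunk_overlap values are illustrative). The rest of the notebook keeps the naive chunks above.
from langchain_text_splitters import RecursiveCharacterTextSplitter

# split on paragraph / sentence / word boundaries, keeping chunks under ~200 characters
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
robust_chunks = text_splitter.split_text(input_text)
print(robust_chunks)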
Traditional chunk embedding
traditional_embeddings = model.encode(chunks)
print(
    f"Number of chunks: {len(traditional_embeddings)}, Embedding dimensions: {len(traditional_embeddings[0])}"
)
print(f"Embedding sample[:10]: {traditional_embeddings[0][:10]}...")
Number of chunks: 3, Embedding dimensions: 768
Embedding sample[:10]: [-0.7992611 -0.67268556 0.9821002 0.28078204 -0.08286519 0.0186394
0.14283076 0.13469528 0.14336902 -0.04381512]...
Long Context chunk embedding
Pseudo-code:
- For each chunk in the text:
  - For each individual word in the chunk:
    - Get word_context = (all text before the word) + the word
    - Calculate the embedding of the word as the embedding of word_context  # this embedding now includes all preceding text as the word's context
  - Calculate the chunk embedding as the average of its word embeddings.
Note:
- We do not limit the length of the word context; we use all of the text that precedes the word. The paper mentions that this limit could be used as a hyperparameter (a minimal sketch of such a limit is included after the long context embeddings below).
import numpy as np


def calc_chunk_long_context_embeddings(chunks, model):
    chunks_embeddings = []
    left_words = []  # running list of all words seen so far (the growing context)
    for chunk in chunks:
        chunk_subchunks = []
        for word in chunk.strip().split(" "):
            left_words.append(word)
            # the "sub-chunk" for this word is all preceding text plus the word itself
            chunk_subchunks.append(" ".join(left_words))
        # embed every sub-chunk and average them to get the chunk embedding
        chunk_embeddings = model.encode(chunk_subchunks)
        chunk_embeddings_avg = np.mean(chunk_embeddings, axis=0)
        chunks_embeddings.append(chunk_embeddings_avg)
    return chunks_embeddings
long_context_embeddings = calc_chunk_long_context_embeddings(chunks, model)
print(
    f"Number of chunks: {len(chunks)}, long_context_embeddings length: {len(long_context_embeddings)}, Embedding dimensions: {len(long_context_embeddings[0])}"
)
print(f"Embedding sample[:10]: {long_context_embeddings[0][:10]}...")
Number of chunks: 3, long_context_embeddings length: 3, Embedding dimensions: 768
Embedding sample[:10]: [-0.6337145 -0.67458665 0.84166676 0.31604436 -0.21064001 0.24656577
0.12800963 0.03115163 0.15378839 0.05417303]...
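As noted above, the paper mentions that the length of the word context could be treated as a hyperparameter. Below is a minimal sketch of one way to add such a limit; the max_context_words parameter name and its default value are our own assumptions, not part of the paper.
def calc_chunk_long_context_embeddings_limited(chunks, model, max_context_words=100):
    chunks_embeddings = []
    left_words = []
    for chunk in chunks:
        chunk_subchunks = []
        for word in chunk.strip().split(" "):
            left_words.append(word)
            # keep only the last `max_context_words` words as context for this word
            chunk_subchunks.append(" ".join(left_words[-max_context_words:]))
        chunk_embeddings = model.encode(chunk_subchunks)
        chunks_embeddings.append(np.mean(chunk_embeddings, axis=0))
    return chunks_embeddings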
Evaluation
Finally, we compare the similarity of the word "Berlin" with the chunks. The similarity should be higher for the long context method.
import numpy as np
def cos_sim(x, y):
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
berlin_embedding = model.encode("Berlin")
for chunk, long_context_embedding, traditional_embedding in zip(
    chunks, long_context_embeddings, traditional_embeddings
):
    print()
    print(f"Similarity of 'Berlin' vs '{chunk}':")
    print(
        f" long context embedding: {cos_sim(berlin_embedding, long_context_embedding):.3f}"
    )
    print(
        f" traditional embedding: {cos_sim(berlin_embedding, traditional_embedding):.3f}"
    )
Similarity of 'Berlin' vs 'Berlin is the capital and largest city of Germany, both by area and by population':
long context embedding: 0.900
traditional embedding: 0.838
Similarity of 'Berlin' vs 'Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits':
long context embedding: 0.862
traditional embedding: 0.704
Similarity of 'Berlin' vs 'The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.':
long context embedding: 0.859
traditional embedding: 0.753
As you can see, the long context embedding method yields a higher similarity to "Berlin" for every chunk in this case.
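To plug this into other LangChain components, the function above can be wrapped in LangChain's Embeddings interface. Below is a minimal sketch under our own assumptions: the LongContextEmbeddings class name is hypothetical, embed_documents treats the input texts as consecutive chunks of a single document, and queries are embedded without extra context.
from typing import List

from langchain_core.embeddings import Embeddings


class LongContextEmbeddings(Embeddings):
    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # the texts are treated as consecutive chunks of one document, in order
        return [
            emb.tolist()
            for emb in calc_chunk_long_context_embeddings(texts, self.model)
        ]

    def embed_query(self, text: str) -> List[float]:
        # queries are embedded on their own, without extra context
        return self.model.encode(text).tolist()
An instance of this class can then be passed wherever LangChain expects an Embeddings object, for example when building a vector store.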