
Long Context Embedding aka Late Chunking

Based on the paper Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models. The paper proposes a technique for adding document-level context to the individual chunk embeddings.

Code based on the Late Chunking blog post.

All credits to Jina!

This notebook explains how Long Context Embedding can be implemented with LangChain.

Notes:

  • [Opinionated!] This notebook uses the term Long Context Embedding, which is more descriptive than Late Chunking.
  • The text chunking term used in the paper is replaced with the text splitting term used in LangChain.

Set up

!pip install -U transformers

Load the model we want to use for embedding. We choose jinaai/jina-embeddings-v2-base-en, but any other model that supports mean pooling will work. Models with a large maximum context length are preferred for long context embedding.

from transformers import AutoModel, AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
)
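
A quick sanity check of the context window never hurts. The sketch below assumes a BERT-style config where max_position_embeddings and hidden_size are exposed; other models may name these attributes differently.

# Optional: inspect the maximum context length and embedding size of the loaded model
# (attribute names assume a BERT-style config and may differ for other models).
print(f"Max context length: {model.config.max_position_embeddings} tokens")
print(f"Embedding dimensions: {model.config.hidden_size}")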

An illustration of the lost context problem. Here is an excerpt from the Wikipedia article about Berlin. One can see that phrases like β€œits” and β€œthe city” refer to β€œBerlin,” which is mentioned only in the first sentence. This makes it harder for the embedding model to link these references to the correct entity, thereby producing a lower-quality vector representation.

input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

Text splitting

Naive text splitting

We split the text on sentence boundaries, using the ". " character sequence as the separator. For reference, we also build naive_chunks with the separator re-attached to every chunk, although the rest of the notebook works with the raw chunks.

In real life, we would use a more robust and sophisticated text splitter; a minimal sketch using LangChain's RecursiveCharacterTextSplitter follows the output below.

chunks = input_text.split(". ")

# Optionally re-attach the sentence separator to every chunk; the rest of this
# notebook keeps working with the raw `chunks`, so only the last chunk carries its period.
naive_chunks = [(chunk + ".").replace("..", ".") for chunk in chunks]
print(chunks)
['Berlin is the capital and largest city of Germany, both by area and by population', "Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits", 'The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.']
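
Here is a minimal sketch of a more robust alternative using LangChain's RecursiveCharacterTextSplitter. It assumes the langchain-text-splitters package is installed; the separator list and chunk_size/chunk_overlap values are illustrative, not tuned, and the rest of the notebook keeps using the naive chunks above.

# Not used below; shown only as a more realistic alternative to the naive split.
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    separators=[". ", " "],  # prefer sentence boundaries, then whitespace
    chunk_size=200,  # illustrative values, not tuned
    chunk_overlap=0,
)
robust_chunks = splitter.split_text(input_text)
print(robust_chunks)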

Traditional chunk embedding

traditional_embeddings = model.encode(chunks)
print(
    f"Number of chunks: {len(traditional_embeddings)}, Embedding dimensions: {len(traditional_embeddings[0])}"
)
print(f"Embedding sample[:10]: {traditional_embeddings[0][:10]}...")
Number of chunks: 3, Embedding dimensions: 768
Embedding sample[:10]: [-0.7992611 -0.67268556 0.9821002 0.28078204 -0.08286519 0.0186394
0.14283076 0.13469528 0.14336902 -0.04381512]...

Long Context chunk embedding

Pseudo-code:

  • For each chunk in the text:
    • For each word in the chunk:
      • Get word_context = (all text before the word) + the word
      • Calculate the embedding of word_context and use it as the word's embedding # this embedding now includes all preceding text as context
    • Calculate the chunk embedding as the average of its word embeddings.

Note:

  • We do not limit the size of the word context; we use all text that precedes the word. The paper mentions that this limit could be treated as a hyperparameter.

import numpy as np


def calc_chunk_long_context_embeddings(chunks, model):
    """Embed each chunk together with all text that precedes it in the document."""
    chunks_embeddings = []
    left_words = []  # all words seen so far, across all previous chunks
    for chunk in chunks:
        chunk_subchunks = []
        for word in chunk.strip().split(" "):
            left_words.append(word)
            # context for this word = all preceding text + the word itself
            chunk_subchunks.append(" ".join(left_words))
        # one embedding per word context, averaged into a single chunk embedding
        chunk_embeddings = model.encode(chunk_subchunks)
        chunk_embeddings_avg = np.mean(chunk_embeddings, axis=0)
        chunks_embeddings.append(chunk_embeddings_avg)
    return chunks_embeddings


long_context_embeddings = calc_chunk_long_context_embeddings(chunks, model)
print(
    f"Number of chunks: {len(chunks)}, long_context_embeddings length: {len(long_context_embeddings)}, Embedding dimensions: {len(long_context_embeddings[0])}"
)
print(f"Embedding sample[:10]: {long_context_embeddings[0][:10]}...")
Number of chunks: 3, long_context_embeddings length: 3, Embedding dimensions: 768
Embedding sample[:10]: [-0.6337145 -0.67458665 0.84166676 0.31604436 -0.21064001 0.24656577
0.12800963 0.03115163 0.15378839 0.05417303]...
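
For comparison, the late chunking approach from the paper runs a single forward pass over the whole text and then mean-pools the contextualized token embeddings that fall inside each chunk, so only one model call is needed. Below is a minimal sketch of that idea; it assumes the model's forward pass exposes last_hidden_state and that the tokenizer is a fast tokenizer supporting return_offsets_mapping, so verify both for the model you use. The function name is illustrative.

import torch


def calc_late_chunking_embeddings(text, chunks, model, tokenizer):
    # One forward pass over the full text, so every token sees the whole document.
    inputs = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
    offsets = inputs.pop("offset_mapping")[0].tolist()  # character span of each token
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state[0]  # (seq_len, dim)

    embeddings, cursor = [], 0
    for chunk in chunks:
        start = text.index(chunk, cursor)
        end = start + len(chunk)
        cursor = end
        # mean-pool the contextualized tokens whose character span lies inside this chunk
        idx = [i for i, (s, e) in enumerate(offsets) if e > s and s >= start and e <= end]
        embeddings.append(token_embeddings[idx].mean(dim=0).numpy())
    return embeddings


# Example usage (not evaluated below):
# late_chunking_embeddings = calc_late_chunking_embeddings(input_text, chunks, model, tokenizer)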

Evaluation

Finally, we compare the similarity of the embedding of the word "Berlin" with each chunk embedding. The similarity should be higher for the long context method.

import numpy as np


def cos_sim(x, y):
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


berlin_embedding = model.encode("Berlin")

for chunk, long_context_embedding, traditional_embedding in zip(
    chunks, long_context_embeddings, traditional_embeddings
):
    print()
    print(f"Similarity of 'Berlin' vs '{chunk}':")
    print(
        f" long context embedding: {cos_sim(berlin_embedding, long_context_embedding):.3f}"
    )
    print(
        f" traditional embedding: {cos_sim(berlin_embedding, traditional_embedding):.3f}"
    )

Similarity of 'Berlin' vs 'Berlin is the capital and largest city of Germany, both by area and by population':
long context embedding: 0.900
traditional embedding: 0.838

Similarity of 'Berlin' vs 'Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits':
long context embedding: 0.862
traditional embedding: 0.704

Similarity of 'Berlin' vs 'The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.':
long context embedding: 0.859
traditional embedding: 0.753

As you can see, the long context method produces higher similarity scores for every chunk in this case.
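
To plug this into the LangChain ecosystem, the long context embedding could be wrapped in a custom Embeddings class. The sketch below is hypothetical, not an official LangChain integration: the class name is made up, and it assumes the caller passes the ordered chunks of a single document to embed_documents.

from langchain_core.embeddings import Embeddings


class LongContextEmbeddings(Embeddings):
    """Hypothetical wrapper: treats `texts` as ordered chunks of a single document."""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts):
        # each chunk is embedded together with all chunks that precede it
        return [
            emb.tolist()
            for emb in calc_chunk_long_context_embeddings(texts, self.model)
        ]

    def embed_query(self, text):
        # queries are embedded on their own, without any document context
        return self.model.encode(text).tolist()


# Example usage (reuses the model and chunks defined above):
# embeddings = LongContextEmbeddings(model)
# vectors = embeddings.embed_documents(chunks)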

