import numpy as np
import pyarrow as pa
import lancedb
from pylate import models
# 1) Late-interaction (ColBERT-style) model via PyLate.
# encode(..., is_query=...) yields one (num_tokens, dim) matrix per input text.
model = models.ColBERT(model_name_or_path="lightonai/GTE-ModernColBERT-v1")
# Probe the embedding dimension from one encoded text rather than hardcoding it.
dim = model.encode(["hello"], is_query=True)[0].shape[1]
# 2) LanceDB table schema with a multivector column.
db = lancedb.connect("./pylate_lancedb")
# "mv" is a multivector: a variable-length list of fixed-size float32 vectors
# (list<list<float32, dim>>) — one inner vector per document token.
multivector_type = pa.list_(pa.list_(pa.float32(), dim))
schema = pa.schema([
    pa.field("doc_id", pa.string()),
    pa.field("text", pa.string()),
    pa.field("mv", multivector_type),
])
# Toy corpus to index: each record carries a string id and its raw text.
docs = [
    {"doc_id": doc_id, "text": text}
    for doc_id, text in [
        ("1", "The train to Tokyo leaves at 5pm."),
        ("2", "That Pho restaurant in Hanoi is highly rated."),
        ("3", "This is a noodle bar in Osaka, Japan."),
    ]
]
# 3) Encode documents with PyLate: one (num_tokens, dim) token matrix per doc.
doc_texts = [d["text"] for d in docs]
doc_embs = model.encode(doc_texts, is_query=False)
# Arrow ingestion wants plain nested Python lists, so cast each token matrix
# to float32 (matching the schema) and convert before insertion.
# Idiom fix: build the row list with a comprehension instead of a manual
# for/append loop (ruff PERF401).
rows = [
    {**d, "mv": np.asarray(emb, dtype=np.float32).tolist()}
    for d, emb in zip(docs, doc_embs)
]
# 4) Create (or replace) the table and run a multivector search.
tbl = db.create_table("docs", data=rows, schema=schema, mode="overwrite")
# An ANN index only pays off at scale; for small corpora (< 100k records)
# brute-force search is fine, so index creation stays commented out.
# tbl.create_index(vector_column_name="mv", metric="cosine")
query = "Tell me about ramen in Japan"
# A ColBERT query is itself a token matrix of shape (num_query_tokens, dim).
q_emb = np.asarray(model.encode([query], is_query=True)[0], dtype=np.float32)
# Multivector search: pass the whole query matrix, take the top 5 hits.
out = tbl.search(q_emb).limit(5).to_pandas()
print(out[["doc_id", "text"]])