import logging
import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.config import Property, DataType
import weaviate.classes as wvc
from weaviate.classes.query import MetadataQuery
from typing import List, Optional, Tuple, Union
from autorag.vectordb import BaseVectorStore
logger = logging.getLogger("AutoRAG")
# [docs]
class Weaviate(BaseVectorStore):
    """Vector store that keeps embeddings in a Weaviate collection.

    Embeddings are produced client-side through ``self.embedding`` and the
    collection is created without a server-side vectorizer. When the Weaviate
    service cannot be reached, or the collection cannot be prepared, the
    instance switches to the in-memory store enabled through
    ``_enable_in_memory_store`` and every public method delegates to the
    matching ``_in_memory_*`` helper.
    """

    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        similarity_metric: str = "cosine",
        client_type: str = "docker",
        host: str = "localhost",
        port: int = 8080,
        grpc_port: int = 50051,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        text_key: str = "content",
    ):
        """Connect to Weaviate and make sure the target collection exists.

        :param embedding_model: Forwarded unchanged to the base class.
        :param collection_name: Name of the Weaviate collection to use; it is
            created when it does not exist yet.
        :param embedding_batch: Forwarded unchanged to the base class.
        :param similarity_metric: One of ``cosine``, ``ip`` or ``l2``.
        :param client_type: ``docker`` connects with ``connect_to_local``,
            ``cloud`` connects with ``connect_to_weaviate_cloud``.
        :param host: Host used for the ``docker`` client type.
        :param port: HTTP port used for the ``docker`` client type.
        :param grpc_port: gRPC port used for the ``docker`` client type.
        :param url: Cluster URL used for the ``cloud`` client type.
        :param api_key: API key used for the ``cloud`` client type.
        :param text_key: Name of the text property, used both in the
            collection schema and when objects are written by ``add``.
        :raises ValueError: If ``client_type`` or ``similarity_metric`` is
            not one of the supported values.
        """
        super().__init__(embedding_model, similarity_metric, embedding_batch)
        self.text_key = text_key

        # Configuration mistakes are reported before any connection is
        # attempted: they must not be mistaken for an unavailable service
        # (which triggers the in-memory fallback), and raising here cannot
        # leave an already connected client behind.
        if client_type not in ("docker", "cloud"):
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: docker, cloud"
            )
        distance_metric = self._resolve_distance_metric(similarity_metric)

        fallback_key = f"weaviate:{collection_name}"

        try:
            if client_type == "docker":
                self.client = weaviate.connect_to_local(
                    host=host,
                    port=port,
                    grpc_port=grpc_port,
                )
            else:
                self.client = weaviate.connect_to_weaviate_cloud(
                    cluster_url=url,
                    auth_credentials=Auth.api_key(api_key),
                )
        except Exception as exc:
            # Deliberate best-effort: any connection problem degrades to the
            # in-memory store instead of failing construction.
            logger.warning(
                "Falling back to in-memory Weaviate store because the configured service is unavailable: %s",
                exc,
            )
            self._enable_in_memory_store(store_key=fallback_key)
            return

        try:
            if not self.client.collections.exists(collection_name):
                self.client.collections.create(
                    collection_name,
                    properties=[
                        Property(
                            # Must match the key ``add`` writes under.
                            name=self.text_key,
                            data_type=DataType.TEXT,
                            skip_vectorization=True,
                        ),
                    ],
                    vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                    # index type options: hnsw, flat, dynamic
                    vector_index_config=wvc.config.Configure.VectorIndex.hnsw(
                        distance_metric=distance_metric
                    ),
                )
            self.collection = self.client.collections.get(collection_name)
            self.collection_name = collection_name
        except Exception as exc:
            logger.warning(
                "Falling back to in-memory Weaviate store because collection setup failed: %s",
                exc,
            )
            self._enable_in_memory_store(store_key=fallback_key)

    @staticmethod
    def _resolve_distance_metric(similarity_metric: str):
        """Map a similarity metric name to the Weaviate distance constant.

        :raises ValueError: If the metric is not ``cosine``, ``ip`` or ``l2``.
        """
        distances = wvc.config.VectorDistances
        if similarity_metric == "cosine":
            return distances.COSINE
        if similarity_metric == "ip":
            return distances.DOT
        if similarity_metric == "l2":
            return distances.L2_SQUARED
        raise ValueError(
            f"similarity_metric {similarity_metric} is not supported\n"
            "supported similarity metrics are: cosine, ip, l2"
        )

    def _log_failed_objects(self) -> None:
        """Log each object the client reports as failed for the last batch."""
        for failed in self.client.batch.failed_objects:
            logger.error(
                "Failed to add object: %s\nReason: %s",
                failed.original_uuid,
                failed.message,
            )

    async def add(self, ids: List[str], texts: List[str]):
        """Embed ``texts`` and store them under ``ids``.

        Texts are passed through ``self.truncated_inputs`` before being
        embedded. Objects that fail to insert are logged, not raised.

        :param ids: Object ids, used as Weaviate UUIDs; ``ids[i]`` belongs to
            ``texts[i]``.
        :param texts: Texts to embed and store under ``self.text_key``.
        """
        if self._using_in_memory_store():
            await self._in_memory_add(ids, texts)
            return

        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)

        with self.client.batch.dynamic() as batch:
            for i, text in enumerate(texts):
                batch.add_object(
                    collection=self.collection_name,
                    properties={self.text_key: text},
                    uuid=ids[i],
                    vector=text_embeddings[i],
                )
        self._log_failed_objects()

    def add_embedding(self, ids: List[str], embeddings: List[List[float]]):
        """Store precomputed ``embeddings`` under ``ids`` without any text.

        Objects that fail to insert are logged, not raised.

        :param ids: Object ids, used as Weaviate UUIDs.
        :param embeddings: Vectors to store; ``embeddings[i]`` belongs to
            ``ids[i]``.
        """
        if self._using_in_memory_store():
            self._in_memory_add_embedding(ids, embeddings)
            return

        with self.client.batch.dynamic() as batch:
            for i, object_id in enumerate(ids):
                batch.add_object(
                    collection=self.collection_name,
                    uuid=object_id,
                    vector=embeddings[i],
                )
        self._log_failed_objects()

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        """Return the stored vectors for ``ids``, in the same order.

        :param ids: Ids of the objects to look up.
        :return: One vector per requested id.
        :raises KeyError: If a requested id is not among the fetched objects.
            Lookup is by ``str(uuid)`` of each stored object, so ids are
            presumably expected in that canonical form (verify with callers).
        """
        if self._using_in_memory_store():
            return await self._in_memory_fetch(ids)
        if len(ids) == 0:
            # Nothing to look up; avoid sending an empty filter.
            return []

        results = self.collection.query.fetch_objects(
            filters=wvc.query.Filter.by_id().contains_any(ids),
            # Without an explicit limit Weaviate applies its default result
            # limit, which would drop matches for longer id lists.
            limit=len(ids),
            include_vector=True,
        )
        vector_by_id = {
            str(obj.uuid): obj.vector["default"] for obj in results.objects
        }
        return [vector_by_id[_id] for _id in ids]

    async def is_exist(self, ids: List[str]) -> List[bool]:
        """Report, for each id, whether an object with that id is stored.

        :param ids: Ids to check.
        :return: One boolean per id, in the same order as ``ids``.
        """
        if self._using_in_memory_store():
            return await self._in_memory_is_exist(ids)
        if len(ids) == 0:
            return []

        fetched = self.collection.query.fetch_objects(
            filters=wvc.query.Filter.by_id().contains_any(ids),
            # See ``fetch``: the default result limit would otherwise turn
            # existing ids into false negatives.
            limit=len(ids),
        )
        existing_ids = {str(obj.uuid) for obj in fetched.objects}
        return [_id in existing_ids for _id in ids]

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        """Run a nearest-vector search for every query text.

        Queries are passed through ``self.truncated_inputs`` and embedded in
        one batch; each embedding is then searched separately.

        :param queries: Query texts.
        :param top_k: Maximum number of hits per query.
        :param kwargs: Accepted for interface compatibility; not used here.
        :return: ``(ids, scores)`` with one inner list per query. Ids are
            returned as strings; scores are the reported distances converted
            by ``distance_to_score``.
        """
        if self._using_in_memory_store():
            return await self._in_memory_query(queries, top_k)

        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        ids, scores = [], []
        for query_embedding in query_embeddings:
            response = self.collection.query.near_vector(
                near_vector=query_embedding,
                limit=top_k,
                return_metadata=MetadataQuery(distance=True),
            )
            # The client hands back UUID objects; stringify them so the
            # result matches the annotated return type and the string ids
            # used by ``fetch`` and ``is_exist``.
            ids.append([str(hit.uuid) for hit in response.objects])
            scores.append(
                [
                    distance_to_score(hit.metadata.distance, self.similarity_metric)
                    for hit in response.objects
                ]
            )
        return ids, scores

    async def delete(self, ids: List[str]):
        """Delete the objects whose ids are listed in ``ids``.

        :param ids: Ids of the objects to remove; an empty list is a no-op.
        """
        if self._using_in_memory_store():
            await self._in_memory_delete(ids)
            return
        if len(ids) == 0:
            return

        id_filter = wvc.query.Filter.by_id().contains_any(ids)
        self.collection.data.delete_many(where=id_filter)

    def delete_collection(self):
        """Delete the whole collection this store was created for."""
        if self._using_in_memory_store():
            self._in_memory_delete_collection()
            return
        self.client.collections.delete(self.collection_name)
# [docs]
def distance_to_score(distance: float, similarity_metric) -> float:
    """Turn a distance into a score, so smaller distances score higher.

    :param distance: Distance value to convert.
    :param similarity_metric: ``cosine`` yields ``1 - distance``; ``ip`` and
        ``l2`` both yield the negated distance.
    :return: The converted score.
    :raises ValueError: If ``similarity_metric`` is none of the three
        supported names.
    """
    if similarity_metric == "cosine":
        return 1 - distance

    # Inner product and L2 share the same conversion: flip the sign.
    if similarity_metric in ("ip", "l2"):
        return -distance

    raise ValueError(
        f"similarity_metric {similarity_metric} is not supported\n"
        "supported similarity metrics are: cosine, ip, l2"
    )