From bab5611c944f85dfb6c121afb9451cadb8a09113 Mon Sep 17 00:00:00 2001 From: HeYQ Date: Sun, 27 Oct 2024 21:46:23 +0800 Subject: [PATCH 1/3] update: analyticdb postgresql vector --- spring-ai-alibaba-core/pom.xml | 30 ++ .../ai/autoconfigure/redis/GetBeanUtil.java | 58 +++ .../redis/RedisClientService.java | 49 +++ .../ai/autoconfigure/redis/RedisConfig.java | 29 ++ .../ai/dashscope/rag/AnalyticdbConfig.java | 157 +++++++ .../ai/dashscope/rag/AnalyticdbVector.java | 411 ++++++++++++++++++ 6 files changed, 734 insertions(+) create mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java create mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java create mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java create mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java create mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java diff --git a/spring-ai-alibaba-core/pom.xml b/spring-ai-alibaba-core/pom.xml index aaf42997..97f3ca5d 100644 --- a/spring-ai-alibaba-core/pom.xml +++ b/spring-ai-alibaba-core/pom.xml @@ -75,6 +75,36 @@ 4.12.0 + + com.aliyun + gpdb20160503 + 3.0.0 + + + + com.alibaba + fastjson + 1.2.75 + + + + + + + + + + + + + + + + org.springframework.boot + spring-boot-starter-data-redis + + + org.springframework.ai diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java new file mode 100644 index 00000000..edb891d8 --- /dev/null +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java @@ -0,0 +1,58 @@ +package com.alibaba.cloud.ai.autoconfigure.redis; + +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.stereotype.Component; + +/** + * @author HeYQ + * @version 1.0 + * @date 2024-10-27 17:12 + * @describe + */ +@Component +public class GetBeanUtil implements ApplicationContextAware { + private static ApplicationContext applicationContext = null; + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + if (GetBeanUtil.applicationContext == null) { + GetBeanUtil.applicationContext = applicationContext; + } + } + + /** + * @return ApplicationContext + */ + public static ApplicationContext getApplicationContext() { + return applicationContext; + } + + /** + * @param beanName beanName + * @return bean + */ + public static Object getBean(String beanName) { + return applicationContext.getBean(beanName); + } + + /** + * @param c c + * @param 泛型 + * @return bean + */ + public static T getBean(Class c) { + return applicationContext.getBean(c); + } + + /** + * @param c c + * @param name + * @param + * @return T + */ + public static T getBean(String name, Class c) { + return getApplicationContext().getBean(name, c); + } +} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java new file mode 100644 index 00000000..6fa27309 --- /dev/null +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java @@ -0,0 +1,49 @@ + +package com.alibaba.cloud.ai.autoconfigure.redis; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.stereotype.Component; +import java.util.concurrent.TimeUnit; + +/** + * Redis client
+ * @author HeYQ + * @version + */ +@Component +public class RedisClientService { + + @Autowired + RedisTemplate redisTemplate; + + /** + * @param key + * @param releaseTime + * @return + */ + public boolean lock(String key, long releaseTime) { + // 尝试获取锁 + Boolean boo = redisTemplate.opsForValue().setIfAbsent(key, "0", releaseTime, TimeUnit.SECONDS); + // 判断结果 + return boo != null && boo; + } + + /** + * @param key + */ + public void deleteLock(String key) { + // 删除key即可释放锁 + redisTemplate.delete(key); + } + + public void setKeyValue(String key, String value, int timeout) { + // Set the cache key to indicate the collection exists + redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS); + } + + public String getValueByKey(String key) { + return (String) redisTemplate.opsForValue().get(key); + } + +} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java new file mode 100644 index 00000000..5cc6fdcf --- /dev/null +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java @@ -0,0 +1,29 @@ +package com.alibaba.cloud.ai.autoconfigure.redis; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.serializer.StringRedisSerializer; + +/** + * @author HeYQ + * @version 1.0 + * @date 2024-10-27 17:36 + * @describe + */ +@Configuration +public class RedisConfig { + + @Bean + public RedisTemplate redisTemplate(RedisConnectionFactory factory) { + RedisTemplate template = new RedisTemplate<>(); + template.setConnectionFactory(factory); + + template.setKeySerializer(new StringRedisSerializer()); + template.setValueSerializer(new StringRedisSerializer()); + template.afterPropertiesSet(); + + return template; + } +} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java new file mode 100644 index 00000000..beab2833 --- /dev/null +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java @@ -0,0 +1,157 @@ +package com.alibaba.cloud.ai.dashscope.rag; + +/** + * @author HeYQ + * @version 1.0 + * @date 2024-10-23 20:22 + * @describe + */ +import com.fasterxml.jackson.annotation.JsonInclude; +import java.util.HashMap; +import java.util.Map; + +@JsonInclude(JsonInclude.Include.NON_NULL) +public class AnalyticdbConfig { + + private String accessKeyId; + + private String accessKeySecret; + + private String regionId; + + private String DBInstanceId; + + private String managerAccount; + + private String managerAccountPassword; + + private String namespace; + + private String namespacePassword; + + private String metrics = "cosine"; + + private Integer readTimeout = 60000; + + private Long embeddingDimension = 1536L; + + public AnalyticdbConfig() { + + } + + + public AnalyticdbConfig(String accessKeyId, String accessKeySecret, String regionId, String DBInstanceId, + String managerAccount, String managerAccountPassword, + String namespace, String namespacePassword, + String metrics, Integer readTimeout, Long embeddingDimension) { + this.accessKeyId = accessKeyId; + this.accessKeySecret = accessKeySecret; + this.regionId = regionId; + this.DBInstanceId = DBInstanceId; + this.managerAccount = managerAccount; + this.managerAccountPassword = managerAccountPassword; + this.namespace = namespace; + this.namespacePassword = namespacePassword; + this.metrics = metrics; + this.readTimeout = readTimeout; + this.embeddingDimension = embeddingDimension; + } + + + public Map toAnalyticdbClientParams() { + Map params = new HashMap<>(); + params.put("accessKeyId", this.accessKeyId); + params.put("accessKeySecret", this.accessKeySecret); + params.put("regionId", this.regionId); + params.put("readTimeout", this.readTimeout); + return params; + } + + public String getAccessKeyId() { + return accessKeyId; + } + + public void setAccessKeyId(String accessKeyId) { + this.accessKeyId = accessKeyId; + } + + public String getAccessKeySecret() { + return accessKeySecret; + } + + public void setAccessKeySecret(String accessKeySecret) { + this.accessKeySecret = accessKeySecret; + } + + public String getRegionId() { + return regionId; + } + + public void setRegionId(String regionId) { + this.regionId = regionId; + } + + public String getDBInstanceId() { + return DBInstanceId; + } + + public void setDBInstanceId(String DBInstanceId) { + this.DBInstanceId = DBInstanceId; + } + + public String getManagerAccount() { + return managerAccount; + } + + public void setManagerAccount(String managerAccount) { + this.managerAccount = managerAccount; + } + + public String getManagerAccountPassword() { + return managerAccountPassword; + } + + public void setManagerAccountPassword(String managerAccountPassword) { + this.managerAccountPassword = managerAccountPassword; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + public String getNamespacePassword() { + return namespacePassword; + } + + public void setNamespacePassword(String namespacePassword) { + this.namespacePassword = namespacePassword; + } + + public String getMetrics() { + return metrics; + } + + public void setMetrics(String metrics) { + this.metrics = metrics; + } + + public Integer getReadTimeout() { + return readTimeout; + } + + public void setReadTimeout(Integer readTimeout) { + this.readTimeout = readTimeout; + } + + public Long getEmbeddingDimension() { + return embeddingDimension; + } + + public void setEmbeddingDimension(Long embeddingDimension) { + this.embeddingDimension = embeddingDimension; + } +} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java new file mode 100644 index 00000000..b4510b43 --- /dev/null +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java @@ -0,0 +1,411 @@ +package com.alibaba.cloud.ai.dashscope.rag; + +import com.alibaba.cloud.ai.autoconfigure.redis.GetBeanUtil; +import com.alibaba.cloud.ai.autoconfigure.redis.RedisClientService; +import com.alibaba.fastjson.JSONObject; +import com.aliyun.gpdb20160503.Client; +import com.aliyun.gpdb20160503.models.CreateCollectionRequest; +import com.aliyun.gpdb20160503.models.CreateNamespaceRequest; +import com.aliyun.gpdb20160503.models.DeleteCollectionDataRequest; +import com.aliyun.gpdb20160503.models.DeleteCollectionDataResponse; +import com.aliyun.gpdb20160503.models.DeleteCollectionRequest; +import com.aliyun.gpdb20160503.models.DescribeCollectionRequest; +import com.aliyun.gpdb20160503.models.DescribeNamespaceRequest; +import com.aliyun.gpdb20160503.models.InitVectorDatabaseRequest; +import com.aliyun.gpdb20160503.models.QueryCollectionDataRequest; +import com.aliyun.gpdb20160503.models.QueryCollectionDataResponse; +import com.aliyun.gpdb20160503.models.QueryCollectionDataResponseBody; +import com.aliyun.gpdb20160503.models.UpsertCollectionDataRequest; +import com.aliyun.teaopenapi.models.Config; +import com.aliyun.tea.TeaException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSON; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * @author HeYQ + * @version 1.0 + * @ date 2024-10-23 20:29 + * @ describe + */ +public class AnalyticdbVector implements VectorStore { + + private static final Logger logger = LoggerFactory.getLogger(AnalyticdbVector.class); + + private static volatile AnalyticdbVector instance; + + private static Boolean initialized = false; + + private final String collectionName; + + private AnalyticdbConfig config; + + private Client client; + + private RedisClientService redisClientService ; + + public static AnalyticdbVector getInstance(String collectionName, AnalyticdbConfig config) throws Exception { + if (instance == null) { + synchronized (AnalyticdbVector.class) { + if (instance == null) { + instance = new AnalyticdbVector(collectionName, config); + } + } + } + return instance; + } + + private AnalyticdbVector(String collectionName, AnalyticdbConfig config) throws Exception { + // collection_name must be updated every time + this.collectionName = collectionName.toLowerCase(); + if (AnalyticdbVector.initialized) { + return; + } + this.config = config; + this.redisClientService = GetBeanUtil.getBean(RedisClientService.class); + try{ + Config clientConfig = Config.build(this.config.toAnalyticdbClientParams()); + this.client = new Client(clientConfig); + }catch(Exception e){ + logger.debug("create Analyticdb client error", e); + } + initialize(); + AnalyticdbVector.initialized = Boolean.TRUE; + } + + /** + * initialize vector db + */ + private void initialize() throws Exception { + initializeVectorDataBase(); + createNameSpaceIfNotExists(); + } + + private void initializeVectorDataBase() { + InitVectorDatabaseRequest request = new InitVectorDatabaseRequest() + .setDBInstanceId(config.getDBInstanceId()) + .setRegionId(config.getRegionId()) + .setManagerAccount(config.getManagerAccount()) + .setManagerAccountPassword(config.getManagerAccountPassword()); + try{ + client.initVectorDatabase(request); + }catch(Exception e){ + logger.error("init Vector data base error",e); + } + } + + private void createNameSpaceIfNotExists() throws Exception { + try{ + DescribeNamespaceRequest request = new DescribeNamespaceRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()); + this.client.describeNamespace(request); + }catch(TeaException e){ + if (Objects.equals(e.getStatusCode(), 404)){ + CreateNamespaceRequest request = new CreateNamespaceRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()) + .setNamespacePassword(this.config.getNamespacePassword()); + this.client.createNamespace(request); + }else { + throw new Exception("failed to create namespace:{}", e); + } + } + } + + private void createCollectionIfNotExists(Long embeddingDimension) { + String cacheKey = "vector_indexing_" + this.collectionName; + String lockName = cacheKey + "_lock"; + try { + boolean lock = redisClientService.lock(lockName, 20);// Acquire the lock + if(lock) { + // Check if the collection exists in the cache + if ("1".equals(redisClientService.getValueByKey(cacheKey))) { + redisClientService.deleteLock(lockName); + return; + } + // Describe the collection to check if it exists + DescribeCollectionRequest describeRequest = new DescribeCollectionRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName); + try { + this.client.describeCollection(describeRequest); + } catch (TeaException e) { + if (e.getStatusCode() == 404) { + // Collection does not exist, create it + String metadata = JSON.toJSONString(new JSONObject() + .fluentPut("refDocId", "text") + .fluentPut("content", "text") + .fluentPut("metadata", "jsonb")); + + String fullTextRetrievalFields = "content"; + CreateCollectionRequest createRequest = new CreateCollectionRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()) + .setNamespace(this.config.getNamespace()) + .setCollection(this.collectionName) + .setDimension(embeddingDimension) + .setMetrics(this.config.getMetrics()) + .setMetadata(metadata) + .setFullTextRetrievalFields(fullTextRetrievalFields); + this.client.createCollection(createRequest); + } else { + throw new RuntimeException("Failed to create collection " + this.collectionName + ": " + e.getMessage()); + } + // Set the cache key to indicate the collection exists + redisClientService.setKeyValue(cacheKey, "1", 3600); + } + } + }catch(Exception e){ + redisClientService.deleteLock(lockName); + throw new RuntimeException("Failed to create collection " + this.collectionName + ": " + e.getMessage()); + }finally { + redisClientService.deleteLock(lockName); + } + } + + public void create(List texts, List> embeddings) throws Exception { + long dimension = embeddings.get(0).size(); + createCollectionIfNotExists(dimension); + addTexts(texts, embeddings); + } + + @Override + public void add(List texts) { + try { + createCollectionIfNotExists(this.config.getEmbeddingDimension()); + addTexts(texts); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public void addTexts(List documents, List> embeddings) throws Exception { + List rows = new ArrayList<>(10); + for (int i = 0; i < documents.size(); i++) { + Document doc = documents.get(i); + List embedding = embeddings.get(i); + + Map metadata = new HashMap<>(); + metadata.put("id", (String) doc.getMetadata().get("id")); + metadata.put("content", doc.getContent()); + metadata.put("metadata", JSON.toJSONString(doc.getMetadata())); + rows.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(embedding).setMetadata(metadata)); + } + UpsertCollectionDataRequest request = new UpsertCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setRows(rows); + this.client.upsertCollectionData(request); + } + + public void addTexts(List documents) throws Exception { + List rows = new ArrayList<>(10); + for (Document doc : documents) { + float[] floatEmbeddings = doc.getEmbedding(); + List embedding = new ArrayList<>(floatEmbeddings.length); + for (float floatEmbedding : floatEmbeddings) { + embedding.add((double) floatEmbedding); + } + Map metadata = new HashMap<>(); + metadata.put("refDocId", (String) doc.getMetadata().get("docId")); + metadata.put("content", doc.getContent()); + metadata.put("metadata", JSON.toJSONString(doc.getMetadata())); + rows.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(embedding).setMetadata(metadata)); + } + UpsertCollectionDataRequest request = new UpsertCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setRows(rows); + this.client.upsertCollectionData(request); + } + + public boolean textExists(String id) { + QueryCollectionDataRequest request = new QueryCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setMetrics(this.config.getMetrics()) + .setIncludeValues(true) + .setVector(null) + .setContent(null) + .setTopK(1L) + .setFilter("refDocId='" + id + "'"); + + try { + QueryCollectionDataResponse response = this.client.queryCollectionData(request); + return response.getBody().getMatches().getMatch().size() > 0; + } catch (Exception e) { + throw new RuntimeException("Failed to query collection data: " + e.getMessage(), e); + } + } + + @Override + public Optional delete(List ids) { + if (ids.isEmpty()) { + return Optional.of(false); + } + String idsStr = ids.stream() + .map(id -> "'" + id + "'") + .collect(Collectors.joining(", ", "(", ")")); + DeleteCollectionDataRequest request = new DeleteCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setCollectionData(null) + .setCollectionDataFilter("refDocId IN " + idsStr); + try { + DeleteCollectionDataResponse deleteCollectionDataResponse = this.client.deleteCollectionData(request); + return deleteCollectionDataResponse.statusCode.equals(200) ? Optional.of(true) : Optional.of(false); + // Handle response if needed + } catch (Exception e) { + throw new RuntimeException("Failed to delete collection data by IDs: " + e.getMessage(), e); + } + } + + public void deleteByMetadataField(String key, String value) { + DeleteCollectionDataRequest request = new DeleteCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setCollectionData(null) + .setCollectionDataFilter("metadata ->> '" + key + "' = '" + value + "'"); + try { + this.client.deleteCollectionData(request); + // Handle response if needed + } catch (Exception e) { + throw new RuntimeException("Failed to delete collection data by metadata field: " + e.getMessage(), e); + } + } + + + public List searchByVector(List queryVector, Map kwargs) { + Double scoreThreshold = (Double) kwargs.getOrDefault("scoreThreshold", 0.0d); + Boolean includeValues = (Boolean) kwargs.getOrDefault("includeValues", true); + Long topK = (Long) kwargs.getOrDefault("topK", 4); + + QueryCollectionDataRequest request = new QueryCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setIncludeValues(includeValues) + .setMetrics(this.config.getMetrics()) + .setVector(queryVector) + .setContent(null) + .setTopK(topK) + .setFilter(null); + + try { + QueryCollectionDataResponse response = this.client.queryCollectionData(request); + List documents = new ArrayList<>(); + for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch match : response.getBody().getMatches().getMatch()) { + if (match.getScore() != null && match.getScore() > scoreThreshold) { + Map metadata = match.getMetadata(); + String pageContent = metadata.get("content"); + Map metadataJson = JSONObject.parseObject(metadata.get("metadata"), HashMap.class); + Document doc = new Document(pageContent, metadataJson); + documents.add(doc); + } + } + return documents; + } catch (Exception e) { + throw new RuntimeException("Failed to search by vector: " + e.getMessage(), e); + } + } + + @Override + public List similaritySearch(String query) { + + return similaritySearch(SearchRequest.query(query)); + + } + + @Override + public List similaritySearch(SearchRequest searchRequest) { + double scoreThreshold = searchRequest.getSimilarityThreshold(); + boolean includeValues = searchRequest.hasFilterExpression(); + int topK = searchRequest.getTopK(); + + QueryCollectionDataRequest request = new QueryCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setIncludeValues(includeValues) + .setMetrics(this.config.getMetrics()) + .setVector(null) + .setContent(searchRequest.query) + .setTopK((long) topK) + .setFilter(null); + try { + QueryCollectionDataResponse response = this.client.queryCollectionData(request); + List documents = new ArrayList<>(); + for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch match : response.getBody().getMatches().getMatch()) { + if (match.getScore() != null && match.getScore() > scoreThreshold) { +// System.out.println(match.getScore()); + Map metadata = match.getMetadata(); + String pageContent = metadata.get("content"); + Map metadataJson = JSONObject.parseObject(metadata.get("metadata"), HashMap.class); + Document doc = new Document(pageContent, metadataJson); + documents.add(doc); + } + } + return documents; + } catch (Exception e) { + throw new RuntimeException("Failed to search by full text: " + e.getMessage(), e); + } + } + + public void deleteAll() { + DeleteCollectionRequest request = new DeleteCollectionRequest() + .setCollection(this.collectionName) + .setDBInstanceId(this.config.getDBInstanceId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setRegionId(this.config.getRegionId()); + + try { + this.client.deleteCollection(request); + } catch (Exception e) { + throw new RuntimeException("Failed to delete collection: " + e.getMessage(), e); + } + } + +} From 13eb116c459570c8e2c841b15135c56eab08170f Mon Sep 17 00:00:00 2001 From: HeYQ Date: Sun, 27 Oct 2024 21:47:09 +0800 Subject: [PATCH 2/3] update: analyticdb postgresql vector --- spring-ai-alibaba-core/pom.xml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/spring-ai-alibaba-core/pom.xml b/spring-ai-alibaba-core/pom.xml index 97f3ca5d..331d5b17 100644 --- a/spring-ai-alibaba-core/pom.xml +++ b/spring-ai-alibaba-core/pom.xml @@ -87,18 +87,6 @@ 1.2.75
- - - - - - - - - - - - org.springframework.boot spring-boot-starter-data-redis From 7c945dc698e539f9c8d9c1b878e3b161cf75f53c Mon Sep 17 00:00:00 2001 From: HeYQ Date: Mon, 28 Oct 2024 16:10:20 +0800 Subject: [PATCH 3/3] update: analyticdb postgresql vector --- spring-ai-alibaba-core/pom.xml | 6 - .../ai/autoconfigure/redis/GetBeanUtil.java | 58 -- .../redis/RedisClientService.java | 49 -- .../ai/autoconfigure/redis/RedisConfig.java | 29 - .../ai/dashscope/rag/AnalyticdbConfig.java | 253 ++++---- .../ai/dashscope/rag/AnalyticdbVector.java | 601 +++++++----------- 6 files changed, 367 insertions(+), 629 deletions(-) delete mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java delete mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java delete mode 100644 spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java diff --git a/spring-ai-alibaba-core/pom.xml b/spring-ai-alibaba-core/pom.xml index 331d5b17..2c77da6d 100644 --- a/spring-ai-alibaba-core/pom.xml +++ b/spring-ai-alibaba-core/pom.xml @@ -87,12 +87,6 @@ 1.2.75 - - org.springframework.boot - spring-boot-starter-data-redis - - - org.springframework.ai diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java deleted file mode 100644 index edb891d8..00000000 --- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/GetBeanUtil.java +++ /dev/null @@ -1,58 +0,0 @@ -package com.alibaba.cloud.ai.autoconfigure.redis; - -import org.springframework.beans.BeansException; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; -import org.springframework.stereotype.Component; - -/** - * @author HeYQ - * @version 1.0 - * @date 2024-10-27 17:12 - * @describe - */ -@Component -public class GetBeanUtil implements ApplicationContextAware { - private static ApplicationContext applicationContext = null; - - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - if (GetBeanUtil.applicationContext == null) { - GetBeanUtil.applicationContext = applicationContext; - } - } - - /** - * @return ApplicationContext - */ - public static ApplicationContext getApplicationContext() { - return applicationContext; - } - - /** - * @param beanName beanName - * @return bean - */ - public static Object getBean(String beanName) { - return applicationContext.getBean(beanName); - } - - /** - * @param c c - * @param 泛型 - * @return bean - */ - public static T getBean(Class c) { - return applicationContext.getBean(c); - } - - /** - * @param c c - * @param name - * @param - * @return T - */ - public static T getBean(String name, Class c) { - return getApplicationContext().getBean(name, c); - } -} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java deleted file mode 100644 index 6fa27309..00000000 --- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisClientService.java +++ /dev/null @@ -1,49 +0,0 @@ - -package com.alibaba.cloud.ai.autoconfigure.redis; - -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.redis.core.RedisTemplate; -import org.springframework.stereotype.Component; -import java.util.concurrent.TimeUnit; - -/** - * Redis client
- * @author HeYQ - * @version - */ -@Component -public class RedisClientService { - - @Autowired - RedisTemplate redisTemplate; - - /** - * @param key - * @param releaseTime - * @return - */ - public boolean lock(String key, long releaseTime) { - // 尝试获取锁 - Boolean boo = redisTemplate.opsForValue().setIfAbsent(key, "0", releaseTime, TimeUnit.SECONDS); - // 判断结果 - return boo != null && boo; - } - - /** - * @param key - */ - public void deleteLock(String key) { - // 删除key即可释放锁 - redisTemplate.delete(key); - } - - public void setKeyValue(String key, String value, int timeout) { - // Set the cache key to indicate the collection exists - redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS); - } - - public String getValueByKey(String key) { - return (String) redisTemplate.opsForValue().get(key); - } - -} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java deleted file mode 100644 index 5cc6fdcf..00000000 --- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/autoconfigure/redis/RedisConfig.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.alibaba.cloud.ai.autoconfigure.redis; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.data.redis.connection.RedisConnectionFactory; -import org.springframework.data.redis.core.RedisTemplate; -import org.springframework.data.redis.serializer.StringRedisSerializer; - -/** - * @author HeYQ - * @version 1.0 - * @date 2024-10-27 17:36 - * @describe - */ -@Configuration -public class RedisConfig { - - @Bean - public RedisTemplate redisTemplate(RedisConnectionFactory factory) { - RedisTemplate template = new RedisTemplate<>(); - template.setConnectionFactory(factory); - - template.setKeySerializer(new StringRedisSerializer()); - template.setValueSerializer(new StringRedisSerializer()); - template.afterPropertiesSet(); - - return template; - } -} diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java index beab2833..c8197a96 100644 --- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbConfig.java @@ -3,8 +3,7 @@ /** * @author HeYQ * @version 1.0 - * @date 2024-10-23 20:22 - * @describe + * @since 2024-10-23 20:22 */ import com.fasterxml.jackson.annotation.JsonInclude; import java.util.HashMap; @@ -13,145 +12,167 @@ @JsonInclude(JsonInclude.Include.NON_NULL) public class AnalyticdbConfig { - private String accessKeyId; + private String accessKeyId; - private String accessKeySecret; + private String accessKeySecret; - private String regionId; + private String regionId; - private String DBInstanceId; + private String DBInstanceId; - private String managerAccount; + private String managerAccount; - private String managerAccountPassword; + private String managerAccountPassword; - private String namespace; + private String namespace; - private String namespacePassword; + private String namespacePassword; - private String metrics = "cosine"; + private String metrics = "cosine"; - private Integer readTimeout = 60000; + private Integer readTimeout = 60000; - private Long embeddingDimension = 1536L; + private Long embeddingDimension = 1536L; - public AnalyticdbConfig() { + private String userAgent = "index"; - } + public AnalyticdbConfig() { + } - public AnalyticdbConfig(String accessKeyId, String accessKeySecret, String regionId, String DBInstanceId, - String managerAccount, String managerAccountPassword, - String namespace, String namespacePassword, - String metrics, Integer readTimeout, Long embeddingDimension) { - this.accessKeyId = accessKeyId; - this.accessKeySecret = accessKeySecret; - this.regionId = regionId; - this.DBInstanceId = DBInstanceId; - this.managerAccount = managerAccount; - this.managerAccountPassword = managerAccountPassword; - this.namespace = namespace; - this.namespacePassword = namespacePassword; - this.metrics = metrics; - this.readTimeout = readTimeout; - this.embeddingDimension = embeddingDimension; - } + public AnalyticdbConfig(String accessKeyId, String accessKeySecret, String regionId, String DBInstanceId, + String managerAccount, String managerAccountPassword, String namespace, String namespacePassword, + String metrics, Integer readTimeout, Long embeddingDimension, String userAgent) { + this.accessKeyId = accessKeyId; + this.accessKeySecret = accessKeySecret; + this.regionId = regionId; + this.DBInstanceId = DBInstanceId; + this.managerAccount = managerAccount; + this.managerAccountPassword = managerAccountPassword; + this.namespace = namespace; + this.namespacePassword = namespacePassword; + this.metrics = metrics; + this.readTimeout = readTimeout; + this.embeddingDimension = embeddingDimension; + this.userAgent = userAgent; + } + public Map toAnalyticdbClientParams() { + Map params = new HashMap<>(); + params.put("accessKeyId", this.accessKeyId); + params.put("accessKeySecret", this.accessKeySecret); + params.put("regionId", this.regionId); + params.put("readTimeout", this.readTimeout); + params.put("userAgent", this.userAgent); + return params; + } - public Map toAnalyticdbClientParams() { - Map params = new HashMap<>(); - params.put("accessKeyId", this.accessKeyId); - params.put("accessKeySecret", this.accessKeySecret); - params.put("regionId", this.regionId); - params.put("readTimeout", this.readTimeout); - return params; - } + public String getAccessKeyId() { + return accessKeyId; + } - public String getAccessKeyId() { - return accessKeyId; - } + public AnalyticdbConfig setAccessKeyId(String accessKeyId) { + this.accessKeyId = accessKeyId; + return this; + } - public void setAccessKeyId(String accessKeyId) { - this.accessKeyId = accessKeyId; - } + public String getAccessKeySecret() { + return accessKeySecret; + } - public String getAccessKeySecret() { - return accessKeySecret; - } + public AnalyticdbConfig setAccessKeySecret(String accessKeySecret) { + this.accessKeySecret = accessKeySecret; + return this; + } - public void setAccessKeySecret(String accessKeySecret) { - this.accessKeySecret = accessKeySecret; - } + public String getRegionId() { + return regionId; + } - public String getRegionId() { - return regionId; - } + public AnalyticdbConfig setRegionId(String regionId) { + this.regionId = regionId; + return this; + } - public void setRegionId(String regionId) { - this.regionId = regionId; - } + public String getDBInstanceId() { + return DBInstanceId; + } - public String getDBInstanceId() { - return DBInstanceId; - } + public AnalyticdbConfig setDBInstanceId(String DBInstanceId) { + this.DBInstanceId = DBInstanceId; + return this; + } - public void setDBInstanceId(String DBInstanceId) { - this.DBInstanceId = DBInstanceId; - } + public String getManagerAccount() { + return managerAccount; + } - public String getManagerAccount() { - return managerAccount; - } + public AnalyticdbConfig setManagerAccount(String managerAccount) { + this.managerAccount = managerAccount; + return this; + } - public void setManagerAccount(String managerAccount) { - this.managerAccount = managerAccount; - } + public String getManagerAccountPassword() { + return managerAccountPassword; + } + + public AnalyticdbConfig setManagerAccountPassword(String managerAccountPassword) { + this.managerAccountPassword = managerAccountPassword; + return this; + } + + public String getNamespace() { + return namespace; + } + + public AnalyticdbConfig setNamespace(String namespace) { + this.namespace = namespace; + return this; + } + + public String getNamespacePassword() { + return namespacePassword; + } + + public AnalyticdbConfig setNamespacePassword(String namespacePassword) { + this.namespacePassword = namespacePassword; + return this; + } + + public String getMetrics() { + return metrics; + } + + public AnalyticdbConfig setMetrics(String metrics) { + this.metrics = metrics; + return this; + } + + public Integer getReadTimeout() { + return readTimeout; + } + + public AnalyticdbConfig setReadTimeout(Integer readTimeout) { + this.readTimeout = readTimeout; + return this; + } + + public Long getEmbeddingDimension() { + return embeddingDimension; + } + + public AnalyticdbConfig setEmbeddingDimension(Long embeddingDimension) { + this.embeddingDimension = embeddingDimension; + return this; + } + + public String getUserAgent() { + return userAgent; + } + + public AnalyticdbConfig setUserAgent(String userAgent) { + this.userAgent = userAgent; + return this; + } - public String getManagerAccountPassword() { - return managerAccountPassword; - } - - public void setManagerAccountPassword(String managerAccountPassword) { - this.managerAccountPassword = managerAccountPassword; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - public String getNamespacePassword() { - return namespacePassword; - } - - public void setNamespacePassword(String namespacePassword) { - this.namespacePassword = namespacePassword; - } - - public String getMetrics() { - return metrics; - } - - public void setMetrics(String metrics) { - this.metrics = metrics; - } - - public Integer getReadTimeout() { - return readTimeout; - } - - public void setReadTimeout(Integer readTimeout) { - this.readTimeout = readTimeout; - } - - public Long getEmbeddingDimension() { - return embeddingDimension; - } - - public void setEmbeddingDimension(Long embeddingDimension) { - this.embeddingDimension = embeddingDimension; - } } diff --git a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java index b4510b43..dda9fcd0 100644 --- a/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java +++ b/spring-ai-alibaba-core/src/main/java/com/alibaba/cloud/ai/dashscope/rag/AnalyticdbVector.java @@ -1,17 +1,15 @@ package com.alibaba.cloud.ai.dashscope.rag; -import com.alibaba.cloud.ai.autoconfigure.redis.GetBeanUtil; -import com.alibaba.cloud.ai.autoconfigure.redis.RedisClientService; import com.alibaba.fastjson.JSONObject; import com.aliyun.gpdb20160503.Client; import com.aliyun.gpdb20160503.models.CreateCollectionRequest; import com.aliyun.gpdb20160503.models.CreateNamespaceRequest; import com.aliyun.gpdb20160503.models.DeleteCollectionDataRequest; import com.aliyun.gpdb20160503.models.DeleteCollectionDataResponse; -import com.aliyun.gpdb20160503.models.DeleteCollectionRequest; import com.aliyun.gpdb20160503.models.DescribeCollectionRequest; import com.aliyun.gpdb20160503.models.DescribeNamespaceRequest; import com.aliyun.gpdb20160503.models.InitVectorDatabaseRequest; +import com.aliyun.gpdb20160503.models.InitVectorDatabaseResponse; import com.aliyun.gpdb20160503.models.QueryCollectionDataRequest; import com.aliyun.gpdb20160503.models.QueryCollectionDataResponse; import com.aliyun.gpdb20160503.models.QueryCollectionDataResponseBody; @@ -36,376 +34,237 @@ /** * @author HeYQ * @version 1.0 - * @ date 2024-10-23 20:29 - * @ describe + * @since 2024-10-23 20:29 */ public class AnalyticdbVector implements VectorStore { - private static final Logger logger = LoggerFactory.getLogger(AnalyticdbVector.class); - - private static volatile AnalyticdbVector instance; - - private static Boolean initialized = false; - - private final String collectionName; - - private AnalyticdbConfig config; - - private Client client; - - private RedisClientService redisClientService ; - - public static AnalyticdbVector getInstance(String collectionName, AnalyticdbConfig config) throws Exception { - if (instance == null) { - synchronized (AnalyticdbVector.class) { - if (instance == null) { - instance = new AnalyticdbVector(collectionName, config); - } - } - } - return instance; - } - - private AnalyticdbVector(String collectionName, AnalyticdbConfig config) throws Exception { - // collection_name must be updated every time - this.collectionName = collectionName.toLowerCase(); - if (AnalyticdbVector.initialized) { - return; - } - this.config = config; - this.redisClientService = GetBeanUtil.getBean(RedisClientService.class); - try{ - Config clientConfig = Config.build(this.config.toAnalyticdbClientParams()); - this.client = new Client(clientConfig); - }catch(Exception e){ - logger.debug("create Analyticdb client error", e); - } - initialize(); - AnalyticdbVector.initialized = Boolean.TRUE; - } - - /** - * initialize vector db - */ - private void initialize() throws Exception { - initializeVectorDataBase(); - createNameSpaceIfNotExists(); - } - - private void initializeVectorDataBase() { - InitVectorDatabaseRequest request = new InitVectorDatabaseRequest() - .setDBInstanceId(config.getDBInstanceId()) - .setRegionId(config.getRegionId()) - .setManagerAccount(config.getManagerAccount()) - .setManagerAccountPassword(config.getManagerAccountPassword()); - try{ - client.initVectorDatabase(request); - }catch(Exception e){ - logger.error("init Vector data base error",e); - } - } - - private void createNameSpaceIfNotExists() throws Exception { - try{ - DescribeNamespaceRequest request = new DescribeNamespaceRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setManagerAccount(this.config.getManagerAccount()) - .setManagerAccountPassword(this.config.getManagerAccountPassword()); - this.client.describeNamespace(request); - }catch(TeaException e){ - if (Objects.equals(e.getStatusCode(), 404)){ - CreateNamespaceRequest request = new CreateNamespaceRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setManagerAccount(this.config.getManagerAccount()) - .setManagerAccountPassword(this.config.getManagerAccountPassword()) - .setNamespacePassword(this.config.getNamespacePassword()); - this.client.createNamespace(request); - }else { - throw new Exception("failed to create namespace:{}", e); - } - } - } - - private void createCollectionIfNotExists(Long embeddingDimension) { - String cacheKey = "vector_indexing_" + this.collectionName; - String lockName = cacheKey + "_lock"; - try { - boolean lock = redisClientService.lock(lockName, 20);// Acquire the lock - if(lock) { - // Check if the collection exists in the cache - if ("1".equals(redisClientService.getValueByKey(cacheKey))) { - redisClientService.deleteLock(lockName); - return; - } - // Describe the collection to check if it exists - DescribeCollectionRequest describeRequest = new DescribeCollectionRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName); - try { - this.client.describeCollection(describeRequest); - } catch (TeaException e) { - if (e.getStatusCode() == 404) { - // Collection does not exist, create it - String metadata = JSON.toJSONString(new JSONObject() - .fluentPut("refDocId", "text") - .fluentPut("content", "text") - .fluentPut("metadata", "jsonb")); - - String fullTextRetrievalFields = "content"; - CreateCollectionRequest createRequest = new CreateCollectionRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setManagerAccount(this.config.getManagerAccount()) - .setManagerAccountPassword(this.config.getManagerAccountPassword()) - .setNamespace(this.config.getNamespace()) - .setCollection(this.collectionName) - .setDimension(embeddingDimension) - .setMetrics(this.config.getMetrics()) - .setMetadata(metadata) - .setFullTextRetrievalFields(fullTextRetrievalFields); - this.client.createCollection(createRequest); - } else { - throw new RuntimeException("Failed to create collection " + this.collectionName + ": " + e.getMessage()); - } - // Set the cache key to indicate the collection exists - redisClientService.setKeyValue(cacheKey, "1", 3600); - } - } - }catch(Exception e){ - redisClientService.deleteLock(lockName); - throw new RuntimeException("Failed to create collection " + this.collectionName + ": " + e.getMessage()); - }finally { - redisClientService.deleteLock(lockName); - } - } - - public void create(List texts, List> embeddings) throws Exception { - long dimension = embeddings.get(0).size(); - createCollectionIfNotExists(dimension); - addTexts(texts, embeddings); - } - - @Override - public void add(List texts) { - try { - createCollectionIfNotExists(this.config.getEmbeddingDimension()); - addTexts(texts); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public void addTexts(List documents, List> embeddings) throws Exception { - List rows = new ArrayList<>(10); - for (int i = 0; i < documents.size(); i++) { - Document doc = documents.get(i); - List embedding = embeddings.get(i); - - Map metadata = new HashMap<>(); - metadata.put("id", (String) doc.getMetadata().get("id")); - metadata.put("content", doc.getContent()); - metadata.put("metadata", JSON.toJSONString(doc.getMetadata())); - rows.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(embedding).setMetadata(metadata)); - } - UpsertCollectionDataRequest request = new UpsertCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setRows(rows); - this.client.upsertCollectionData(request); - } - - public void addTexts(List documents) throws Exception { - List rows = new ArrayList<>(10); - for (Document doc : documents) { - float[] floatEmbeddings = doc.getEmbedding(); - List embedding = new ArrayList<>(floatEmbeddings.length); - for (float floatEmbedding : floatEmbeddings) { - embedding.add((double) floatEmbedding); - } - Map metadata = new HashMap<>(); - metadata.put("refDocId", (String) doc.getMetadata().get("docId")); - metadata.put("content", doc.getContent()); - metadata.put("metadata", JSON.toJSONString(doc.getMetadata())); - rows.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(embedding).setMetadata(metadata)); - } - UpsertCollectionDataRequest request = new UpsertCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setRows(rows); - this.client.upsertCollectionData(request); - } - - public boolean textExists(String id) { - QueryCollectionDataRequest request = new QueryCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setMetrics(this.config.getMetrics()) - .setIncludeValues(true) - .setVector(null) - .setContent(null) - .setTopK(1L) - .setFilter("refDocId='" + id + "'"); - - try { - QueryCollectionDataResponse response = this.client.queryCollectionData(request); - return response.getBody().getMatches().getMatch().size() > 0; - } catch (Exception e) { - throw new RuntimeException("Failed to query collection data: " + e.getMessage(), e); - } - } - - @Override - public Optional delete(List ids) { - if (ids.isEmpty()) { - return Optional.of(false); - } - String idsStr = ids.stream() - .map(id -> "'" + id + "'") - .collect(Collectors.joining(", ", "(", ")")); - DeleteCollectionDataRequest request = new DeleteCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setCollectionData(null) - .setCollectionDataFilter("refDocId IN " + idsStr); - try { - DeleteCollectionDataResponse deleteCollectionDataResponse = this.client.deleteCollectionData(request); - return deleteCollectionDataResponse.statusCode.equals(200) ? Optional.of(true) : Optional.of(false); - // Handle response if needed - } catch (Exception e) { - throw new RuntimeException("Failed to delete collection data by IDs: " + e.getMessage(), e); - } - } - - public void deleteByMetadataField(String key, String value) { - DeleteCollectionDataRequest request = new DeleteCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setCollectionData(null) - .setCollectionDataFilter("metadata ->> '" + key + "' = '" + value + "'"); - try { - this.client.deleteCollectionData(request); - // Handle response if needed - } catch (Exception e) { - throw new RuntimeException("Failed to delete collection data by metadata field: " + e.getMessage(), e); - } - } - - - public List searchByVector(List queryVector, Map kwargs) { - Double scoreThreshold = (Double) kwargs.getOrDefault("scoreThreshold", 0.0d); - Boolean includeValues = (Boolean) kwargs.getOrDefault("includeValues", true); - Long topK = (Long) kwargs.getOrDefault("topK", 4); - - QueryCollectionDataRequest request = new QueryCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setIncludeValues(includeValues) - .setMetrics(this.config.getMetrics()) - .setVector(queryVector) - .setContent(null) - .setTopK(topK) - .setFilter(null); - - try { - QueryCollectionDataResponse response = this.client.queryCollectionData(request); - List documents = new ArrayList<>(); - for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch match : response.getBody().getMatches().getMatch()) { - if (match.getScore() != null && match.getScore() > scoreThreshold) { - Map metadata = match.getMetadata(); - String pageContent = metadata.get("content"); - Map metadataJson = JSONObject.parseObject(metadata.get("metadata"), HashMap.class); - Document doc = new Document(pageContent, metadataJson); - documents.add(doc); - } - } - return documents; - } catch (Exception e) { - throw new RuntimeException("Failed to search by vector: " + e.getMessage(), e); - } - } - - @Override - public List similaritySearch(String query) { - - return similaritySearch(SearchRequest.query(query)); - - } - - @Override - public List similaritySearch(SearchRequest searchRequest) { - double scoreThreshold = searchRequest.getSimilarityThreshold(); - boolean includeValues = searchRequest.hasFilterExpression(); - int topK = searchRequest.getTopK(); - - QueryCollectionDataRequest request = new QueryCollectionDataRequest() - .setDBInstanceId(this.config.getDBInstanceId()) - .setRegionId(this.config.getRegionId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setCollection(this.collectionName) - .setIncludeValues(includeValues) - .setMetrics(this.config.getMetrics()) - .setVector(null) - .setContent(searchRequest.query) - .setTopK((long) topK) - .setFilter(null); - try { - QueryCollectionDataResponse response = this.client.queryCollectionData(request); - List documents = new ArrayList<>(); - for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch match : response.getBody().getMatches().getMatch()) { - if (match.getScore() != null && match.getScore() > scoreThreshold) { -// System.out.println(match.getScore()); - Map metadata = match.getMetadata(); - String pageContent = metadata.get("content"); - Map metadataJson = JSONObject.parseObject(metadata.get("metadata"), HashMap.class); - Document doc = new Document(pageContent, metadataJson); - documents.add(doc); - } - } - return documents; - } catch (Exception e) { - throw new RuntimeException("Failed to search by full text: " + e.getMessage(), e); - } - } - - public void deleteAll() { - DeleteCollectionRequest request = new DeleteCollectionRequest() - .setCollection(this.collectionName) - .setDBInstanceId(this.config.getDBInstanceId()) - .setNamespace(this.config.getNamespace()) - .setNamespacePassword(this.config.getNamespacePassword()) - .setRegionId(this.config.getRegionId()); - - try { - this.client.deleteCollection(request); - } catch (Exception e) { - throw new RuntimeException("Failed to delete collection: " + e.getMessage(), e); - } - } + private static final Logger logger = LoggerFactory.getLogger(AnalyticdbVector.class); + + private Boolean initialized = false; + + private final String collectionName; + + private AnalyticdbConfig config; + + private Client client; + + public AnalyticdbVector(String collectionName, AnalyticdbConfig config) throws Exception { + // collection_name must be updated every time + this.collectionName = collectionName.toLowerCase(); + this.config = config; + Config clientConfig = Config.build(this.config.toAnalyticdbClientParams()); + this.client = new Client(clientConfig); + logger.debug("created AnalyticdbVector client success"); + } + + /** + * initialize vector db + */ + private void initialize() throws Exception { + if (!this.initialized) { + initializeVectorDataBase(); + createNameSpaceIfNotExists(); + createCollectionIfNotExists(this.config.getEmbeddingDimension()); + this.initialized = true; + } + } + + private void initializeVectorDataBase() throws Exception { + InitVectorDatabaseRequest request = new InitVectorDatabaseRequest().setDBInstanceId(config.getDBInstanceId()) + .setRegionId(config.getRegionId()) + .setManagerAccount(config.getManagerAccount()) + .setManagerAccountPassword(config.getManagerAccountPassword()); + InitVectorDatabaseResponse initVectorDatabaseResponse = client.initVectorDatabase(request); + logger.debug("successfully initialize vector database, response body:{}", initVectorDatabaseResponse.getBody()); + + } + + private void createNameSpaceIfNotExists() throws Exception { + try { + DescribeNamespaceRequest request = new DescribeNamespaceRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()); + this.client.describeNamespace(request); + } + catch (TeaException e) { + if (Objects.equals(e.getStatusCode(), 404)) { + CreateNamespaceRequest request = new CreateNamespaceRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()) + .setNamespacePassword(this.config.getNamespacePassword()); + this.client.createNamespace(request); + } + else { + throw new Exception("failed to create namespace:{}", e); + } + } + } + + private void createCollectionIfNotExists(Long embeddingDimension) throws Exception { + try { + // Describe the collection to check if it exists + DescribeCollectionRequest describeRequest = new DescribeCollectionRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName); + this.client.describeCollection(describeRequest); + logger.debug("collection" + this.collectionName + "already exists"); + } + catch (TeaException e) { + if (Objects.equals(e.getStatusCode(), 404)) { + // Collection does not exist, create it + String metadata = JSON.toJSONString(new JSONObject().fluentPut("refDocId", "text") + .fluentPut("content", "text") + .fluentPut("metadata", "jsonb")); + String fullTextRetrievalFields = "content"; + CreateCollectionRequest createRequest = new CreateCollectionRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setManagerAccount(this.config.getManagerAccount()) + .setManagerAccountPassword(this.config.getManagerAccountPassword()) + .setNamespace(this.config.getNamespace()) + .setCollection(this.collectionName) + .setDimension(embeddingDimension) + .setMetrics(this.config.getMetrics()) + .setMetadata(metadata) + .setFullTextRetrievalFields(fullTextRetrievalFields); + this.client.createCollection(createRequest); + logger.debug("collection" + this.collectionName + "created"); + } + else { + throw new RuntimeException( + "Failed to create collection " + this.collectionName + ": " + e.getMessage()); + } + } + } + + @Override + public void add(List texts) { + try { + initialize(); + addTexts(texts); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + public void addTexts(List documents) throws Exception { + List rows = new ArrayList<>(10); + for (Document doc : documents) { + float[] floatEmbeddings = doc.getEmbedding(); + List embedding = new ArrayList<>(floatEmbeddings.length); + for (float floatEmbedding : floatEmbeddings) { + embedding.add((double) floatEmbedding); + } + Map metadata = new HashMap<>(); + metadata.put("refDocId", (String) doc.getMetadata().get("docId")); + metadata.put("content", doc.getContent()); + metadata.put("metadata", JSON.toJSONString(doc.getMetadata())); + rows.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(embedding) + .setMetadata(metadata)); + } + UpsertCollectionDataRequest request = new UpsertCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setRows(rows); + this.client.upsertCollectionData(request); + } + + @Override + public Optional delete(List ids) { + try { + initialize(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + return deleteByIds(ids); + } + + public Optional deleteByIds(List ids) { + if (ids.isEmpty()) { + return Optional.of(false); + } + String idsStr = ids.stream().map(id -> "'" + id + "'").collect(Collectors.joining(", ", "(", ")")); + DeleteCollectionDataRequest request = new DeleteCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setCollectionData(null) + .setCollectionDataFilter("refDocId IN " + idsStr); + try { + DeleteCollectionDataResponse deleteCollectionDataResponse = this.client.deleteCollectionData(request); + return deleteCollectionDataResponse.statusCode.equals(200) ? Optional.of(true) : Optional.of(false); + // Handle response if needed + } + catch (Exception e) { + throw new RuntimeException("Failed to delete collection data by IDs: " + e.getMessage(), e); + } + } + + @Override + public List similaritySearch(String query) { + try { + initialize(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + return similaritySearch(SearchRequest.query(query)); + + } + + @Override + public List similaritySearch(SearchRequest searchRequest) { + double scoreThreshold = searchRequest.getSimilarityThreshold(); + boolean includeValues = searchRequest.hasFilterExpression(); + int topK = searchRequest.getTopK(); + + QueryCollectionDataRequest request = new QueryCollectionDataRequest() + .setDBInstanceId(this.config.getDBInstanceId()) + .setRegionId(this.config.getRegionId()) + .setNamespace(this.config.getNamespace()) + .setNamespacePassword(this.config.getNamespacePassword()) + .setCollection(this.collectionName) + .setIncludeValues(includeValues) + .setMetrics(this.config.getMetrics()) + .setVector(null) + .setContent(searchRequest.query) + .setTopK((long) topK) + .setFilter(null); + try { + QueryCollectionDataResponse response = this.client.queryCollectionData(request); + List documents = new ArrayList<>(); + for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch match : response.getBody() + .getMatches() + .getMatch()) { + if (match.getScore() != null && match.getScore() > scoreThreshold) { + Map metadata = match.getMetadata(); + String pageContent = metadata.get("content"); + Map metadataJson = JSONObject.parseObject(metadata.get("metadata"), HashMap.class); + Document doc = new Document(pageContent, metadataJson); + documents.add(doc); + } + } + return documents; + } + catch (Exception e) { + throw new RuntimeException("Failed to search by full text: " + e.getMessage(), e); + } + } }