Spaces:
Paused
Paused
| import flask_restful | |
| from flask import current_app, request | |
| from flask_login import current_user | |
| from flask_restful import Resource, marshal, marshal_with, reqparse | |
| from werkzeug.exceptions import Forbidden, NotFound | |
| import services | |
| from controllers.console import api | |
| from controllers.console.apikey import api_key_fields, api_key_list | |
| from controllers.console.app.error import ProviderNotInitializeError | |
| from controllers.console.datasets.error import DatasetNameDuplicateError | |
| from controllers.console.setup import setup_required | |
| from controllers.console.wraps import account_initialization_required | |
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |
| from core.indexing_runner import IndexingRunner | |
| from core.model_runtime.entities.model_entities import ModelType | |
| from core.provider_manager import ProviderManager | |
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |
| from extensions.ext_database import db | |
| from fields.app_fields import related_app_list | |
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |
| from fields.document_fields import document_status_fields | |
| from libs.login import login_required | |
| from models.dataset import Dataset, Document, DocumentSegment | |
| from models.model import ApiToken, UploadFile | |
| from services.dataset_service import DatasetService, DocumentService | |
| def _validate_name(name): | |
| if not name or len(name) < 1 or len(name) > 40: | |
| raise ValueError('Name must be between 1 to 40 characters.') | |
| return name | |
| def _validate_description_length(description): | |
| if len(description) > 400: | |
| raise ValueError('Description cannot exceed 400 characters.') | |
| return description | |
class DatasetListApi(Resource):
    """Collection resource for datasets: paginated listing and creation."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self):
        """Return one page of datasets for the current tenant.

        Reads the query parameters ``page`` (default 1), ``limit``
        (default 20), ``ids``, ``provider`` (default ``"vendor"``),
        ``keyword`` and ``tag_ids``. When ``ids`` is given the datasets are
        fetched by id; otherwise the paginated service query is used.

        Returns:
            A ``(payload, 200)`` tuple. Each marshalled dataset carries an
            extra ``embedding_available`` flag.
        """
        query = request.args
        page = query.get('page', default=1, type=int)
        limit = query.get('limit', default=20, type=int)
        ids = query.getlist('ids')
        provider = query.get('provider', default="vendor")
        search = query.get('keyword', default=None, type=str)
        tag_ids = query.getlist('tag_ids')
        tenant_id = current_user.current_tenant_id

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(
                page, limit, provider, tenant_id, current_user, search, tag_ids
            )

        # Collect the "<model>:<provider>" keys of every active embedding
        # model configured for this tenant.
        configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
        active_embeddings = configurations.get_models(
            model_type=ModelType.TEXT_EMBEDDING,
            only_active=True,
        )
        available = {
            f"{model.model}:{model.provider.provider}" for model in active_embeddings
        }

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            # Only 'high_quality' datasets depend on an embedding model; every
            # other indexing technique is always reported as available.
            item['embedding_available'] = (
                item['indexing_technique'] != 'high_quality'
                or f"{item['embedding_model']}:{item['embedding_model_provider']}" in available
            )

        payload = {
            'data': data,
            # A full page is taken as the hint that more rows may follow.
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page,
        }
        return payload, 200

    def post(self):
        """Create an empty dataset owned by the current tenant.

        Returns:
            A ``(marshalled dataset, 201)`` tuple.

        Raises:
            Forbidden: If the current user is not an admin or owner.
            DatasetNameDuplicateError: If the service reports a duplicate
                dataset name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            'name',
            nullable=False,
            required=True,
            help='type is required. Name must be between 1 to 40 characters.',
            type=_validate_name,
        )
        parser.add_argument(
            'indexing_technique',
            type=str,
            location='json',
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help='Invalid indexing technique.',
        )
        args = parser.parse_args()

        # Only admins and owners may create datasets.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user,
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            # Translate the service-layer error into the console API error.
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201
class DatasetApi(Resource):
    """Single-dataset resource: fetch, partially update and delete."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self, dataset_id):
        """Return one dataset with an ``embedding_available`` flag.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``(marshalled dataset, 200)`` tuple.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)

        # Collect the "<model>:<provider>" keys of every active embedding
        # model configured for this tenant.
        configurations = ProviderManager().get_configurations(
            tenant_id=current_user.current_tenant_id
        )
        active_embeddings = configurations.get_models(
            model_type=ModelType.TEXT_EMBEDDING,
            only_active=True,
        )
        available = {
            f"{model.model}:{model.provider.provider}" for model in active_embeddings
        }

        # Only 'high_quality' datasets depend on an embedding model.
        data['embedding_available'] = (
            data['indexing_technique'] != 'high_quality'
            or f"{data['embedding_model']}:{data['embedding_model_provider']}" in available
        )
        return data, 200

    def patch(self, dataset_id):
        """Update the supplied fields of a dataset.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``(marshalled dataset, 200)`` tuple.

        Raises:
            NotFound: If the dataset is missing before or after the update.
            Forbidden: If the current user is not an admin or owner.
        """
        dataset_id_str = str(dataset_id)
        existing = DatasetService.get_dataset(dataset_id_str)
        if existing is None:
            raise NotFound("Dataset not found.")

        # Check the user's model setting before accepting changes.
        DatasetService.check_dataset_model_setting(existing)

        parser = reqparse.RequestParser()
        parser.add_argument(
            'name',
            nullable=False,
            help='type is required. Name must be between 1 to 40 characters.',
            type=_validate_name,
        )
        parser.add_argument(
            'description',
            location='json',
            store_missing=False,
            type=_validate_description_length,
        )
        parser.add_argument(
            'indexing_technique',
            type=str,
            location='json',
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help='Invalid indexing technique.',
        )
        parser.add_argument(
            'permission',
            type=str,
            location='json',
            choices=('only_me', 'all_team_members'),
            help='Invalid permission.',
        )
        parser.add_argument(
            'embedding_model',
            type=str,
            location='json',
            help='Invalid embedding model.',
        )
        parser.add_argument(
            'embedding_model_provider',
            type=str,
            location='json',
            help='Invalid embedding model provider.',
        )
        parser.add_argument(
            'retrieval_model',
            type=dict,
            location='json',
            help='Invalid retrieval model.',
        )
        args = parser.parse_args()

        # Only admins and owners may modify datasets.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        updated = DatasetService.update_dataset(dataset_id_str, args, current_user)
        if updated is None:
            raise NotFound("Dataset not found.")

        return marshal(updated, dataset_detail_fields), 200

    def delete(self, dataset_id):
        """Delete a dataset.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``({'result': 'success'}, 204)`` tuple.

        Raises:
            Forbidden: If the current user is not an admin or owner.
            NotFound: If the service reports that nothing was deleted.
        """
        # Only admins and owners may delete datasets.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        if not DatasetService.delete_dataset(str(dataset_id), current_user):
            raise NotFound("Dataset not found.")
        return {'result': 'success'}, 204
class DatasetQueryApi(Resource):
    """Paginated listing of the queries recorded for one dataset."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self, dataset_id):
        """Return one page of dataset queries.

        Reads the query parameters ``page`` (default 1) and ``limit``
        (default 20).

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``(payload, 200)`` tuple with ``data``, ``has_more``,
            ``limit``, ``total`` and ``page``.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        queries, total = DatasetService.get_dataset_queries(
            dataset_id=dataset.id,
            page=page,
            per_page=limit,
        )

        payload = {
            'data': marshal(queries, dataset_query_detail_fields),
            # A full page is taken as the hint that more rows may follow.
            'has_more': len(queries) == limit,
            'limit': limit,
            'total': total,
            'page': page,
        }
        return payload, 200
class DatasetIndexingEstimateApi(Resource):
    """Estimate the cost of indexing a set of files or Notion pages."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def post(self):
        """Run an indexing estimate for the submitted data source.

        The JSON body supplies ``info_list``, ``process_rule`` and
        ``indexing_technique`` (all required), plus optional ``doc_form``
        (default ``'text_model'``), ``dataset_id`` and ``doc_language``
        (default ``'English'``).

        Returns:
            A ``(estimate, 200)`` tuple, where ``estimate`` is whatever
            ``IndexingRunner.indexing_estimate`` returns.

        Raises:
            NotFound: If ``upload_file`` is requested and none of the given
                file ids belongs to the current tenant.
            ValueError: If the data source type is neither ``upload_file``
                nor ``notion_import``.
            ProviderNotInitializeError: If the embedding provider is missing
                or not initialised.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('indexing_technique', type=str, required=True,
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False,
                            location='json')
        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        # Validate args before touching the payload structure.
        DocumentService.estimate_args_validate(args)

        tenant_id = current_user.current_tenant_id
        extract_settings = self._build_extract_settings(
            args['info_list'], args['doc_form'], tenant_id
        )

        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
                tenant_id, extract_settings,
                args['process_rule'], args['doc_form'],
                args['doc_language'], args['dataset_id'],
                args['indexing_technique'],
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        return response, 200

    @staticmethod
    def _build_extract_settings(info_list, doc_form, tenant_id):
        """Translate ``info_list`` into a list of ``ExtractSetting`` objects."""
        source_type = info_list['data_source_type']

        if source_type == 'upload_file':
            file_ids = info_list['file_info_list']['file_ids']
            file_details = db.session.query(UploadFile).filter(
                UploadFile.tenant_id == tenant_id,
                UploadFile.id.in_(file_ids)
            ).all()

            # ``Query.all()`` yields a (possibly empty) list and never None,
            # so emptiness is the condition that means "nothing matched".
            if not file_details:
                raise NotFound("File not found.")

            return [
                ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=doc_form,
                )
                for file_detail in file_details
            ]

        if source_type == 'notion_import':
            # One setting per page, across every listed workspace.
            return [
                ExtractSetting(
                    datasource_type="notion_import",
                    notion_info={
                        "notion_workspace_id": notion_info['workspace_id'],
                        "notion_obj_id": page['page_id'],
                        "notion_page_type": page['type'],
                        "tenant_id": tenant_id,
                    },
                    document_model=doc_form,
                )
                for notion_info in info_list['notion_info_list']
                for page in notion_info['pages']
            ]

        raise ValueError('Data source type not support')
class DatasetRelatedAppListApi(Resource):
    """List the apps that are joined to a dataset."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self, dataset_id):
        """Return the apps related to a dataset.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``(payload, 200)`` tuple; the payload is the ``data``/``total``
            envelope marshalled with ``related_app_list``.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the permission check rejects the current user.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        related_apps = []
        for app_dataset_join in DatasetService.get_related_apps(dataset.id):
            # Read ``.app`` once per join and skip joins with no app.
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        # The collected items are model objects, so they must be marshalled
        # before being returned as a response body.
        # NOTE(review): presumably ``related_app_list`` describes this
        # ``data``/``total`` envelope -- confirm against fields.app_fields.
        payload = {
            'data': related_apps,
            'total': len(related_apps),
        }
        return marshal(payload, related_app_list), 200
class DatasetIndexingStatusApi(Resource):
    """Report per-document indexing progress for a dataset."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self, dataset_id):
        """Return the indexing status of every document in a dataset.

        Only documents belonging to the current tenant are considered.
        Segments whose status is ``'re_segment'`` are excluded from both
        counts.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            ``{'data': [...]}`` with one marshalled status per document.
        """
        dataset_id = str(dataset_id)
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.tenant_id == current_user.current_tenant_id,
        ).all()

        statuses = []
        for document in documents:
            document_key = str(document.id)

            completed = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.document_id == document_key,
                DocumentSegment.status != 're_segment',
            ).count()
            total = DocumentSegment.query.filter(
                DocumentSegment.document_id == document_key,
                DocumentSegment.status != 're_segment',
            ).count()

            # Attach the counts to the document before marshalling it.
            document.completed_segments = completed
            document.total_segments = total
            statuses.append(marshal(document, document_status_fields))

        return {'data': statuses}
class DatasetApiKeyApi(Resource):
    """List and create dataset-scoped API keys for the current tenant."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    # Maximum number of keys of this resource type allowed per tenant.
    max_keys = 10
    # Prefix passed to ``ApiToken.generate_api_key``.
    token_prefix = 'dataset-'
    # Value stored in / matched against ``ApiToken.type``.
    resource_type = 'dataset'

    def get(self):
        """Return every dataset API key of the current tenant.

        Returns:
            The ``{"items": keys}`` payload marshalled with ``api_key_list``.
        """
        keys = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.type == self.resource_type,
                ApiToken.tenant_id == current_user.current_tenant_id,
            )
            .all()
        )
        # The rows are model objects, so they must be marshalled before
        # being returned as a response body.
        # NOTE(review): presumably ``api_key_list`` reads the ``items`` key --
        # confirm against controllers.console.apikey.
        return marshal({"items": keys}, api_key_list)

    def post(self):
        """Create a new dataset API key for the current tenant.

        Returns:
            A ``(marshalled key, 200)`` tuple.

        Raises:
            Forbidden: If the current user is not an admin or owner.

        Aborts with HTTP 400 (``code='max_keys_exceeded'``) when the tenant
        already holds ``max_keys`` keys of this type.
        """
        # Only admins and owners may create API keys.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.type == self.resource_type,
                ApiToken.tenant_id == current_user.current_tenant_id,
            )
            .count()
        )

        if current_key_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code='max_keys_exceeded'
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()

        # Marshal the model object with the imported key field spec.
        return marshal(api_token, api_key_fields), 200
class DatasetApiDeleteApi(Resource):
    """Delete one dataset-scoped API key of the current tenant."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    # Value matched against ``ApiToken.type``.
    resource_type = 'dataset'

    def delete(self, api_key_id):
        """Delete the API key identified by ``api_key_id``.

        Args:
            api_key_id: Key identifier taken from the URL.

        Returns:
            A ``({'result': 'success'}, 204)`` tuple.

        Raises:
            Forbidden: If the current user is not an admin or owner.

        Aborts with HTTP 404 when no key with this id and type exists for
        the current tenant.
        """
        api_key_id = str(api_key_id)

        # Only admins and owners may delete API keys.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        # Look the key up with tenant and type constraints first, so another
        # tenant's key is reported as missing.
        owned_key = (
            db.session.query(ApiToken)
            .filter(
                ApiToken.tenant_id == current_user.current_tenant_id,
                ApiToken.type == self.resource_type,
                ApiToken.id == api_key_id,
            )
            .first()
        )
        if owned_key is None:
            flask_restful.abort(404, message='API key not found')

        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
        db.session.commit()

        return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
    """Expose the base URL of the service API."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self):
        """Return ``{'api_base_url': <base> + '/v1'}``.

        ``<base>`` is the ``SERVICE_API_URL`` config value when it is truthy,
        otherwise the request's host URL without its trailing slash.
        """
        base_url = current_app.config['SERVICE_API_URL'] or request.host_url.rstrip('/')
        return {'api_base_url': base_url + '/v1'}
class DatasetRetrievalSettingApi(Resource):
    """Report the retrieval methods supported by the configured vector store."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self):
        """Return the retrieval methods for the ``VECTOR_STORE`` config value.

        Returns:
            ``{'retrieval_method': [...]}``.

        Raises:
            ValueError: If the configured vector store is not recognised.
        """
        vector_type = current_app.config['VECTOR_STORE']

        semantic_only_stores = {"milvus", "relyt", "pgvector", "pgvecto_rs"}
        full_featured_stores = {"qdrant", "weaviate"}

        if vector_type in semantic_only_stores:
            methods = ['semantic_search']
        elif vector_type in full_featured_stores:
            methods = ['semantic_search', 'full_text_search', 'hybrid_search']
        else:
            raise ValueError("Unsupported vector db type.")

        return {'retrieval_method': methods}
class DatasetRetrievalSettingMockApi(Resource):
    """Report the retrieval methods for a vector store named in the URL."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them -- confirm authentication is enforced elsewhere.

    def get(self, vector_type):
        """Return the retrieval methods supported by ``vector_type``.

        Args:
            vector_type: Vector store name taken from the URL.

        Returns:
            ``{'retrieval_method': [...]}``.

        Raises:
            ValueError: If ``vector_type`` is not recognised.
        """
        # 'pgvecto_rs' is included so this endpoint accepts the same stores
        # as DatasetRetrievalSettingApi, which reads the store from config.
        if vector_type in {'milvus', 'relyt', 'pgvector', 'pgvecto_rs'}:
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in {'qdrant', 'weaviate'}:
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")
class DatasetErrorDocs(Resource):
    """List the documents of a dataset that the service reports as errored."""

    # NOTE(review): no login/setup decorators are applied here although the
    # file imports them, and unlike DatasetQueryApi this view does not call
    # check_dataset_permission -- confirm both are intentional.

    def get(self, dataset_id):
        """Return the error documents of a dataset.

        Args:
            dataset_id: Dataset identifier taken from the URL.

        Returns:
            A ``(payload, 200)`` tuple with ``data`` (marshalled documents)
            and ``total`` (their count).

        Raises:
            NotFound: If the dataset does not exist.
        """
        dataset_id_str = str(dataset_id)
        if DatasetService.get_dataset(dataset_id_str) is None:
            raise NotFound("Dataset not found.")

        error_documents = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
        marshalled = [marshal(doc, document_status_fields) for doc in error_documents]

        return {
            'data': marshalled,
            'total': len(error_documents),
        }, 200
# Route table for the dataset console resources. Entries are registered on
# ``api`` in the order listed.
_DATASET_ROUTES = (
    (DatasetListApi, '/datasets'),
    (DatasetApi, '/datasets/<uuid:dataset_id>'),
    (DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries'),
    (DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs'),
    (DatasetIndexingEstimateApi, '/datasets/indexing-estimate'),
    (DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps'),
    (DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status'),
    (DatasetApiKeyApi, '/datasets/api-keys'),
    (DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>'),
    (DatasetApiBaseUrlApi, '/datasets/api-base-info'),
    (DatasetRetrievalSettingApi, '/datasets/retrieval-setting'),
    (DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>'),
)

for _resource, _url in _DATASET_ROUTES:
    api.add_resource(_resource, _url)