Commit
Merge branch 'feat/0.1.9' into release
dolphin0618 committed Oct 31, 2023
2 parents 7e19876 + c454b6f commit 70e4f53
Showing 128 changed files with 14,465 additions and 2,985 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -70,7 +70,7 @@ jobs:
run: |
cd ./src/backend
pip install bisheng_langchain==$RELEASE_VERSION
sed -i 's/^bisheng_langchain.*/bisheng_langchain = "$RELEASE_VERSION"/g' pyproject.toml
sed -i 's/^bisheng_langchain.*/bisheng_langchain = "'$RELEASE_VERSION'"/g' pyproject.toml
poetry lock
cd ../../
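The one-line ci.yml change above is a shell-quoting fix: inside a single-quoted sed expression, $RELEASE_VERSION is never expanded, so the old command wrote the literal text $RELEASE_VERSION into pyproject.toml; closing the quotes around the variable lets the shell substitute the real version. A minimal, hypothetical illustration of the difference (not taken from the workflow):

import os
import subprocess

env = {**os.environ, "RELEASE_VERSION": "0.1.9"}                 # illustrative version
old = "echo 'bisheng_langchain = \"$RELEASE_VERSION\"'"          # variable stays literal
new = "echo 'bisheng_langchain = \"'$RELEASE_VERSION'\"'"        # quotes closed around the variable
for cmd in (old, new):
    print(subprocess.run(["sh", "-c", cmd], env=env, capture_output=True, text=True).stdout.strip())
# bisheng_langchain = "$RELEASE_VERSION"
# bisheng_langchain = "0.1.9"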
26 changes: 0 additions & 26 deletions .vscode/launch.json

This file was deleted.

4 changes: 2 additions & 2 deletions docker/redis/redis.conf
@@ -531,15 +531,15 @@ dir ./
# starting the replication synchronization process, otherwise the master will
# refuse the replica request.
#
# masterauth <master-password>
masterauth E1SkG0PaDMEPTAxY
#
# However this is not enough if you are using Redis ACLs (for Redis version
# 6 or greater), and the default user is not capable of running the PSYNC
# command and/or other commands needed for replication. In this case it's
# better to configure a special user to use with replication, and specify the
# masteruser configuration as such:
#
# masteruser <username>
masteruser bisheng
#
# When masteruser is specified, the replica will authenticate against its
# master using the new AUTH form: AUTH <username> <password>.
3 changes: 2 additions & 1 deletion src/backend/bisheng/api/JWT.py
@@ -1,8 +1,9 @@
from bisheng.settings import settings
from pydantic import BaseModel


class Settings(BaseModel):
authjwt_secret_key: str = 'xI$xO.oN$sC}tC^oQ(fF^nK~dB&uT('
authjwt_secret_key: str = settings.jwt_secret
# Configure application to store and get JWT from cookies
authjwt_token_location: set = {'cookies'}
# Disable CSRF Protection for this example. default is True
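The JWT.py change stops hard-coding the signing secret and reads it from settings.jwt_secret instead. A minimal sketch of how a pydantic settings class like this one is typically registered with fastapi_jwt_auth (the load_config hook is the library's documented mechanism; the secret value is a placeholder and the exact wiring in this repository may differ):

from fastapi_jwt_auth import AuthJWT
from pydantic import BaseModel

class Settings(BaseModel):
    authjwt_secret_key: str = 'replace-me-with-settings.jwt_secret'  # placeholder secret
    authjwt_token_location: set = {'cookies'}
    authjwt_cookie_csrf_protect: bool = False

@AuthJWT.load_config
def get_config():
    # fastapi_jwt_auth calls this once to pick up the authjwt_* fields above
    return Settings()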
8 changes: 6 additions & 2 deletions src/backend/bisheng/api/utils.py
@@ -59,7 +59,11 @@ def build_input_keys_response(langchain_object, artifacts):
return input_keys_response


def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, chat_id=None):
def build_flow(graph_data: dict,
artifacts,
process_file=False,
flow_id=None,
chat_id=None) -> Graph:
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
@@ -193,6 +197,6 @@ def access_check(payload: dict, owner_user_id: int, target_id: int, type: Access
select(RoleAccess).where(RoleAccess.role_id.in_(payload.get('role')),
RoleAccess.type == type.value)).all()
third_ids = [access.third_id for access in role_access]
if owner_user_id != payload.get('user_id') and not third_ids and target_id not in third_ids:
if owner_user_id != payload.get('user_id') and str(target_id) not in third_ids:
return False
return True
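The access_check fix changes two things: RoleAccess.third_id values are compared as strings (hence str(target_id)), and the stray `not third_ids and` clause is dropped, which previously meant access was only ever denied when the requester's roles had no grants at all. A small, self-contained illustration of the corrected logic with made-up ids:

payload = {'user_id': 2, 'role': [5]}     # requester: user 2 with role 5
owner_user_id, target_id = 1, 99          # resource owned by user 1
third_ids = ['17', '42']                  # ids granted to role 5 (stored as strings)

denied = owner_user_id != payload.get('user_id') and str(target_id) not in third_ids
print(not denied)  # False: requester is neither the owner nor covered by a role grant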
30 changes: 27 additions & 3 deletions src/backend/bisheng/api/v1/callback.py
@@ -1,11 +1,13 @@
import asyncio
from typing import Any, Dict, List, Union

from bisheng.api.v1.schemas import ChatResponse
from bisheng.utils.logger import logger
from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentFinish, LLMResult
from langchain.schema.agent import AgentAction
from langchain.schema.document import Document


# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
@@ -19,22 +21,27 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type='stream', intermediate_steps='')
await self.websocket.send_json(resp.dict())

async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str],
**kwargs: Any) -> Any:
"""Run when LLM starts running."""
logger.debug(f'llm_start prompts={prompts}')

async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
"""Run when LLM ends running."""
logger.debug(f'llm_end response={response}')

async def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when LLM errors."""

async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any],
**kwargs: Any) -> Any:
"""Run when chain starts running."""

async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""

async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
async def on_chain_error(self, error: Union[Exception, KeyboardInterrupt],
**kwargs: Any) -> Any:
"""Run when chain errors."""

async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
@@ -100,6 +107,15 @@ async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
)
await self.websocket.send_json(resp.dict())

async def on_retriever_start(self, serialized: Dict[str, Any], query: str,
**kwargs: Any) -> Any:
"""Run when retriever start running."""

async def on_retriever_end(self, result: List[Document], **kwargs: Any) -> Any:
"""Run when retriever end running."""
# todo: check skill permissions
logger.debug(f'retriver_result result={result}')


class StreamingLLMCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming LLM responses."""
@@ -178,3 +194,11 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
asyncio.run_coroutine_threadsafe(coroutine, loop)
except Exception as e:
logger.error(e)

def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any:
"""Run when retriever start running."""

def on_retriever_end(self, result: List[Document], **kwargs: Any) -> Any:
"""Run when retriever end running."""
# todo: check skill permissions
logger.debug(f'retriver_result result={result}')
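Both callback handlers gain on_retriever_start/on_retriever_end hooks (the permission check flagged by the todo is not implemented yet). A minimal, hypothetical sketch of a standalone handler with the same hooks; the handler and chain names are made up for the example:

from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.document import Document

class RetrieverLogger(BaseCallbackHandler):
    def on_retriever_start(self, serialized: Dict[str, Any], query: str, **kwargs: Any) -> Any:
        print(f'retriever started, query={query!r}')

    def on_retriever_end(self, result: List[Document], **kwargs: Any) -> Any:
        print(f'retriever returned {len(result)} documents')

# Usage sketch: pass the handler when running a retrieval chain, e.g.
# qa_chain.run(question, callbacks=[RetrieverLogger()])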
16 changes: 15 additions & 1 deletion src/backend/bisheng/api/v1/endpoints.py
@@ -1,6 +1,8 @@
import json
from typing import Optional

import yaml
from bisheng import settings
from bisheng.api.v1.schemas import ProcessResponse, UploadFileResponse
from bisheng.cache.redis import redis_client
from bisheng.cache.utils import save_uploaded_file
@@ -12,6 +14,7 @@
from bisheng.settings import parse_key
from bisheng.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import delete
from sqlmodel import Session, select

@@ -24,8 +27,17 @@ def get_all():
return langchain_types_dict


@router.get('/env')
def getn_env():
return {'data': settings.settings.environment}


@router.get('/config')
def get_config(session: Session = Depends(get_session)):
def get_config(session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
if payload.get('role') != 'admin':
raise HTTPException(status_code=500, detail='Unauthorized')
configs = session.exec(select(Config)).all()
config_str = []
for config in configs:
@@ -71,6 +83,8 @@ async def process_flow(
"""
Endpoint to process an input with a given flow_id.
"""
if inputs and isinstance(inputs, dict):
inputs.pop('id')

try:
flow = session.get(Flow, flow_id)
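The /config endpoint is now JWT-protected and admin-only: the JWT subject is a JSON string carrying the user's id and role, and any non-admin request is rejected. A tiny sketch of that check with an illustrative token subject (the field values are made up):

import json
from fastapi import HTTPException

subject = json.dumps({'user_id': 1, 'user_name': 'admin', 'role': 'admin'})  # illustrative

payload = json.loads(subject)          # mirrors json.loads(Authorize.get_jwt_subject())
if payload.get('role') != 'admin':
    raise HTTPException(status_code=500, detail='Unauthorized')
print('config is visible to this user')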
8 changes: 5 additions & 3 deletions src/backend/bisheng/api/v1/flows.py
@@ -2,7 +2,7 @@
from typing import List
from uuid import UUID

from bisheng.api.utils import build_flow_no_yield, remove_api_keys
from bisheng.api.utils import access_check, build_flow_no_yield, remove_api_keys
from bisheng.api.v1.schemas import FlowListCreate, FlowListRead
from bisheng.database.base import get_session
from bisheng.database.models.flow import Flow, FlowCreate, FlowRead, FlowReadWithStyle, FlowUpdate
@@ -57,7 +57,7 @@ def read_flows(*,
select(RoleAccess).where(RoleAccess.role_id.in_(payload.get('role')))).all()
if rol_flow_id:
flow_ids = [
acess.third_id for acess in rol_flow_id if acess.type == AccessType.FLOW
acess.third_id for acess in rol_flow_id if acess.type == AccessType.FLOW.value
]
sql = sql.where(or_(Flow.user_id == payload.get('user_id'), Flow.id.in_(flow_ids)))
count_sql = count_sql.where(
@@ -85,6 +85,8 @@ def read_flows(*,
userMap = {user.user_id: user.user_name for user in db_user}
for r in res:
r['user_name'] = userMap[r['user_id']]
r['write'] = True if 'admin' == payload.get('role') or r.get(
'user_id') == payload.get('user_id') else False

return {'data': res, 'total': total_count}

@@ -114,7 +116,7 @@ def update_flow(*,
if not db_flow:
raise HTTPException(status_code=404, detail='Flow not found')

if 'admin' != payload.get('role') and db_flow.user_id != payload.get('user_id'):
if not access_check(payload, db_flow.user_id, flow_id, AccessType.FLOW_WRITE):
raise HTTPException(status_code=500, detail='没有权限编辑此技能')

flow_data = flow.dict(exclude_unset=True)
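read_flows now also returns a per-row write flag so the frontend knows whether the current user may edit each flow. The ternary in the diff is equivalent to a plain boolean expression; a self-contained illustration with placeholder values:

payload = {'user_id': 2, 'role': 'admin'}        # illustrative JWT payload
r = {'user_id': 7, 'user_name': 'alice'}         # one row from the flow listing

# Same result as the `True if ... else False` expression in the diff:
r['write'] = payload.get('role') == 'admin' or r.get('user_id') == payload.get('user_id')
print(r['write'])  # True: admins may edit any flow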
32 changes: 23 additions & 9 deletions src/backend/bisheng/api/v1/knowledge.py
@@ -168,7 +168,9 @@ def get_knowledge(*,
select(RoleAccess).where(RoleAccess.role_id.in_(payload.get('role')))).all()
if role_third_id:
third_ids = [
acess.third_id for acess in role_third_id if acess.type == AccessType.KNOWLEDGE
acess.third_id
for acess in role_third_id
if acess.type == AccessType.KNOWLEDGE.value
]
sql = sql.where(
or_(Knowledge.user_id == payload.get('user_id'), Knowledge.id.in_(third_ids)))
@@ -249,7 +251,7 @@ def delete_knowledge(*,
knowledge = session.get(Knowledge, knowledge_id)
if not knowledge:
raise HTTPException(status_code=404, detail='knowledge not found')
if 'admin' != payload.get('role') and knowledge.user_id != payload.get('user_id'):
if not access_check(payload, knowledge.user_id, knowledge_id, AccessType.KNOWLEDGE_WRITE):
raise HTTPException(status_code=404, detail='没有权限执行操作')
session.delete(knowledge)
session.commit()
@@ -267,9 +269,10 @@ def delete_knowledge_file(*,
knowledge_file = session.get(KnowledgeFile, file_id)
if not knowledge_file:
raise HTTPException(status_code=404, detail='文件不存在')
if 'admin' != payload.get('role') and knowledge_file.user_id != payload.get('user_id'):
raise HTTPException(status_code=404, detail='没有权限执行操作')

knowledge = session.get(Knowledge, knowledge_file.knowledge_id)
if not access_check(payload, knowledge.user_id, knowledge.id, AccessType.KNOWLEDGE_WRITE):
raise HTTPException(status_code=404, detail='没有权限执行操作')
# handle the vector db
collection_name = knowledge.collection_name
embeddings = decide_embeddings(knowledge.model)
@@ -282,6 +285,13 @@ def delete_knowledge_file(*,
# minio
minio_client.MinioClient().delete_minio(str(knowledge_file.id))
# elastic
esvectore_client = decide_vectorstores(collection_name, 'ElasticKeywordsSearch', embeddings)
if esvectore_client:
esvectore_client.client.delete_by_query(index=collection_name,
query={'match': {
'metadata.file_id': file_id
}})
logger.info(f'act=delete_es file_id={file_id} res={res}')

session.delete(knowledge_file)
session.commit()
@@ -326,15 +336,18 @@ async def addEmbedding(collection_name, model: str, chunk_size: int, separator:
for index, path in enumerate(file_paths):
knowledge_file = knowledge_files[index]
try:
texts, metadatas = _read_chunk_text(path, knowledge_file.file_name, chunk_size,
chunk_overlap, separator)
# persist to mysql
session = next(get_session())
db_file = session.get(KnowledgeFile, knowledge_file.id)
setattr(db_file, 'status', 2)
setattr(db_file, 'object_name', knowledge_file.file_name)
session.add(db_file)
session.commit()
session.refresh(db_file)
session.flush()
# original file
minio_client.MinioClient().upload_minio(knowledge_file.file_name, path)

texts, metadatas = _read_chunk_text(path, knowledge_file.file_name, chunk_size,
chunk_overlap, separator)

# source tracing currently depends on minio; replace with a more generic OSS later
minio_client.MinioClient().upload_minio(str(db_file.id), path)
Expand All @@ -346,7 +359,8 @@ async def addEmbedding(collection_name, model: str, chunk_size: int, separator:
# store in es
if es_client:
es_client.add_texts(texts=texts, metadatas=metadatas)

session.commit()
session.refresh(db_file)
except Exception as e:
logger.exception(e)
session = next(get_session())
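Deleting a knowledge file now also purges its chunks from Elasticsearch with a delete_by_query on metadata.file_id. A stand-alone sketch of the same cleanup against a plain elasticsearch-py client (the host, index name and file id below are placeholders):

from elasticsearch import Elasticsearch

es = Elasticsearch('http://localhost:9200')       # placeholder host
es.delete_by_query(
    index='col_demo_knowledge',                   # knowledge.collection_name in the diff
    query={'match': {'metadata.file_id': 123}},   # same filter as the diff
)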
12 changes: 12 additions & 0 deletions src/backend/bisheng/api/v1/qa.py
@@ -2,6 +2,7 @@
from typing import List

from bisheng.database.base import get_session
from bisheng.database.models.knowledge_file import KnowledgeFile
from bisheng.database.models.recall_chunk import RecallChunk
from bisheng.utils import minio_client
from fastapi import APIRouter, Depends
@@ -27,12 +28,23 @@ def get_answer_keyword(message_id: int, session: Session = Depends(get_session))
def get_original_file(*, message_id: int, keys: str, session: Session = Depends(get_session)):
# get the matched keys
chunks = session.exec(select(RecallChunk).where(RecallChunk.message_id == message_id)).all()

if not chunks:
return {'data': [], 'message': 'no chunk found'}

# all files referenced by the chunks
file_ids = {chunk.file_id for chunk in chunks}
db_knowledge_files = session.exec(select(KnowledgeFile).where(KnowledgeFile.id.in_(file_ids)))
id2file = {file.id: file for file in db_knowledge_files}
# keywords
keywords = keys.split(';') if keys else []
result = []
for index, chunk in enumerate(chunks):
file = id2file.get(chunk.file_id)
chunk_res = json.loads(json.loads(chunk.meta_data).get('bbox'))
chunk_res['source_url'] = minio_client.MinioClient().get_share_link(str(chunk.file_id))
chunk_res['original_url'] = minio_client.MinioClient().get_share_link(
file.object_name if file.object_name else str(file.id))
chunk_res['score'] = round(match_score(chunk.chunk, keywords),
2) if len(keywords) > 0 else 0
chunk_res['file_id'] = chunk.file_id
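get_original_file now returns early when no chunks were recalled and, for each chunk, resolves the originating KnowledgeFile so the share link can point at the stored object name, falling back to the file id. A small, self-contained sketch of that lookup with placeholder data:

chunks = [{'file_id': 10}, {'file_id': 11}]                     # recalled chunks (placeholder)
knowledge_files = [{'id': 10, 'object_name': 'report.pdf'},
                   {'id': 11, 'object_name': None}]             # placeholder rows

id2file = {f['id']: f for f in knowledge_files}
for chunk in chunks:
    file = id2file.get(chunk['file_id'])
    key = file['object_name'] if file['object_name'] else str(file['id'])
    print(key)   # 'report.pdf', then '11'; each key is passed to MinioClient().get_share_link(key)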