
[FastAPI x Langchain] Implementing ChatGPT Response Streaming

brad.min 2024. 3. 6. 16:13

While developing a RAG application, I implemented streaming so that the text generated by the LLM comes out one character at a time.

I used a queue to pass the characters along sequentially, together with Langchain's BaseCallbackHandler. It seems I need to study BaseCallbackHandler a little more.


API

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dependency_injector.wiring import inject, Provide

router = APIRouter()

class Question(BaseModel):
    question: str

@router.post(
    path='/text-stream/',
    description="Enter the question"
)
@inject
async def generate_text_streaming(
        query: Question,
        rag_service: Rag_Service = Depends(Provide[Container.rag_service]),
):
    # Stream the service's async generator back to the client as an event stream
    return StreamingResponse(rag_service.generate_text_streaming(query), media_type='text/event-stream')
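
For context, the endpoint assumes a dependency-injector Container that provides Rag_Service. The post doesn't show it, so here is a minimal sketch of what that wiring might look like; the names and structure are my assumption, not the original code:

from dependency_injector import containers, providers

class Container(containers.DeclarativeContainer):
    # Factory gives each request its own Rag_Service (and its own queue),
    # so concurrent requests don't interleave tokens in a shared queue
    rag_service = providers.Factory(Rag_Service)

# Wire the container into the module that defines the route so that
# Provide[Container.rag_service] can be resolved at request time
container = Container()
container.wire(modules=[__name__])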


Service

import asyncio
from queue import Queue
from threading import Thread

from langchain_openai import ChatOpenAI


class Rag_Service:
    def __init__(self):
        # The queue through which the callback handler hands over generated tokens
        self.streamer_queue = Queue()
        self.streaming_handler = StreamingHandler(queue=self.streamer_queue)
        self.LLM = ChatOpenAI(
            streaming=True,
            callbacks=[self.streaming_handler]
        )

    def generate(self, llm, text):
        llm.invoke(text)

    def start_generation(self, llm, text):
        # invoke() blocks until generation finishes, so run it in a separate
        # thread and let the event loop drain the queue in the meantime
        thread = Thread(target=self.generate, kwargs={"llm": llm, "text": text})
        thread.start()

    async def generate_text_streaming(self, text: Question):
        self.start_generation(self.LLM, text.question)
        while True:
            value = self.streamer_queue.get()
            if value is None:  # stop signal pushed by on_llm_end
                break
            yield value
            self.streamer_queue.task_done()
            await asyncio.sleep(0.1)
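
Before hooking it up to the endpoint, the service can be exercised on its own with a small harness like this. This is my own test snippet, not from the original project, and it assumes OPENAI_API_KEY is set in the environment:

import asyncio

async def main():
    service = Rag_Service()
    # Consume tokens as the background thread produces them
    async for token in service.generate_text_streaming(Question(question="What is RAG?")):
        print(token, end="", flush=True)

asyncio.run(main())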


StreamingHandler

from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class StreamingHandler(BaseCallbackHandler):
    def __init__(self, queue) -> None:
        super().__init__()
        self._queue = queue
        self._stop_signal = None  # None marks the end of the stream
        print("Custom handler Initialized")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called once per generated token; push it onto the queue
        self._queue.put(token)

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print("generation started")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        print("\n\ngeneration concluded")
        # Push the stop signal so the consuming loop knows to break
        self._queue.put(self._stop_signal)
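
To watch the streaming from the client side, a small httpx snippet like the following works; the host and port are my assumptions for a local run:

import httpx

# Read the event stream chunk by chunk as the server yields tokens
with httpx.stream(
        "POST",
        "http://localhost:8000/text-stream/",
        json={"question": "What is RAG?"},
        timeout=None,
) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)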