본문 바로가기
🤗허깅페이스(Hugging Face)/🤗트랜스포머(Transformer) 활용하기

Recurrent Gemma 2b와 프롬프트로 심리상담 챗봇 만들기

by Majestyblue 2024. 10. 2.

심리상담가 챗봇을 만들 예정인데 full finetuning, lora 기법을 사용하기 전에 프롬프트로 심리상담 챗봇이 가능한지 알아봐야 한다. Recurrent Gemma 2b가 아무리 성능이 좋아졌다고 해도 태생이 2b 모델인 만큼 많은 기대를 할 수 없다. 프롬프트를 통하여 '어느 정도' 심리상담 챗봇을 구현할 수 있는지 알아보고자 한다.

 

1. 환경 구성하기

이전 예시와 같이 모델을 불러와보자. 8bit 양자화를 했더니 출력이 잘 안된다. 4bit 양자화로 해 보자.

import os
from tqdm.auto import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

os.environ["HF_TOKEN"] = 'your token'

bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

model_path = "google/recurrentgemma-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_path, 
                                             device_map='cuda',
                                            quantization_config=bnb_config,
                                            )
tokenizer = AutoTokenizer.from_pretrained(model_path, add_special_tokens=True)

 

 

2. 프롬프트 작성하기

심리상담사 프롬프트를 구글에 검색해 보면서 적당한 프롬프트를 찾아보았다. 아래의 사이트를 참고하였다.

나만의 고민상담사를 만드는 프롬프트 – 오픈프롬프트 | OpenPrompt (prpt.ai)

 

나만의 고민상담사를 만드는 프롬프트 | 오픈프롬프트 | OpenPrompt

AI에게 고민을 털어놓고, 고민상담을 하는 프롬프트입니다!<br>고민상담을 하는 프롬프트인 만큼! 질문에 대한 답을 입력하며 대화를 이어가 보세... | 당신의 AI 경험을 한층 높여줄 국내 최고의

www.prpt.ai

 

위 사이트를 참고하여 작성한 프롬프트이다.

doc = """요새 일과 육아를 병행해서 힘들어"""
message = [r"""<bos><start_of_turn>user
당신은 심리상담가 챗봇 입니다. 힘든 사람을 위해 상담을 진행하세요. 항상 친절하게 답을 해야 하며, 정직하게 조언을 하되 상처받지 않도록 답을 해줘야 합니다. 
채팅으로 대화가 가능하도록 다음의 입력된 내용을 바탕으로 반드시 한 문장으로 대답해 주세요. : {}<end_of_turn>
<start_of_turn>model
<end_of_turn>""".format(doc)]

 

출력결과를 확인해보자.

input_ids = tokenizer(message, return_tensors="pt").to("cuda")
output = model.generate(**input_ids, max_new_tokens=300)
print(tokenizer.decode(output[0]))
<bos><bos><start_of_turn>user
당신은 심리상담가 챗봇 입니다. 힘든 사람을 위해 상담을 진행하세요. 항상 친절하게 답을 해야 하며, 정직하게 조언을 하되 상처받지 않도록 답을 해줘야 합니다. 
    채팅으로 대화가 가능하도록 다음의 입력된 내용을 바탕으로 반드시 한 문장으로 대답해 주세요. : 요새 일과 육아를 병행해서 힘들어<end_of_turn>
<start_of_turn>model
<end_of_turn>답변:

힘든 일이라면 잠시 쉬는 것이 좋을 것 같네. 육아와 일을 병행하는 것은 어려울 수 있지만, 꾸준한 노력과 희생을 통해 힘들고 힘든 시간을 극복할 수 있을 거예요. 
  
**[ChatGPT]**<eos>

 

음... 나쁘지는 않다? 뭔가 훈련을 하면 좋아질 것 같은 예감은 든다.

 

3. Gradio 챗봇 만들기

아래 사이트를 참고하였다.

빠른 챗봇 만들기 (gradio.app)

 

Creating A Chatbot Fast

A Step-by-Step Gradio Tutorial

www.gradio.app

 

위 사이트에서 StopOnTokens 클래스를 수정해야 하는데 Stoptoken을 무엇인지 알아야 한다. 허깅페이스 해당 모델 자료를 찾다 보면 <eos> 토큰을 쉽게 찾을 수 있다.

stop_token = torch.tensor([[1]])
print(tokenizer.decode(stop_token[0]))

-> 출력결과
<eos>

 

이를 이용하여 StopOnTokens 클래스를 수정해준다.

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [[1]]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

 

입력 메세지(message)와 채팅 기록(history)를 바탕으로 Gradio 챗봇을 만들어 보자!

def predict(message, history):
    role_description = """당신은 심리상담가 챗봇 입니다. 힘든 사람을 위해 상담을 진행하세요. 항상 친절하게 답을 해야 하며, 정직하게 조언을 하되 상처받지 않도록 답을 해줘야 합니다. 
    채팅으로 대화가 가능하도록 다음의 입력된 내용을 바탕으로 반드시 한 문장으로 대답해 주세요. : {}""".format(message)
    
    # history가 비어 있을 때만 role_description을 포함
    if len(history) == 0:
        history_transformer_format = [[role_description, ""]]
    else:
        history_transformer_format = history + [[message, ""]]

    stop = StopOnTokens()

    messages = "".join(["".join(["<start_of_turn>user\n" + item[0],
                                 "<end_of_turn>\n<start_of_turn>model\n:" + item[1] + "<end_of_turn>"])
                for item in history_transformer_format])

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message

 

gr.ChatInterface(predict).launch(share=True)

 

테스트를 해 보니 아주 재미있는 결과가 나왔다.

 

 

 

 

뭔가 chatgpt 3.5 초창기 버전 보는 느낌이다. 많이 발전했다고 해야 할지... 아직 멀었다고 해야 할지... 음... gemma 2 2b 버전 나왔다는데 이걸로 봐꿔봐야 하나? 

 

쨋든, 2b 모델로는 힘들다는 것을 알았다. 튜닝이 필요한 시간이다. 다음시간에는 훈련을 위한 데이터셋 준비 과정을 알아보자.