티스토리 뷰

반응형

TensorRT-LLM 프레임워크는 일반적인 모델을 TensorRT Engine으로 빌드 하도록 지원해준다. Engine으로 빌드하는 것 뿐만 아니라 다양한 양자화 기법들을 제공한다. GPU에 서빙을 위해 자신만의 엔진으로 최적화 한다고 생각하면 된다. 자세한 내용은 깃헙에 들어가면 많은 설명이 있다. (개인적으로 헷갈리는 부분이 많았다.)

TensorRT-LLM 서빙 헤커톤에 참여하여 A100 GPU를 가지고 모델 서빙과 양자화 등 여러 경험을 할 수 있었다. 대상 모델로 DeepSeek 모델을 선정하였고 이를 실제로 TensorRT 엔진으로 변환하고 서빙해보기로 하였다.

1. 환경 설정

모델을 가지고 이것 저것 해볼때 오류를 많이 접하게 되는데 라이브러리의 버전 이슈로 인하여 발생한 오류가 굉장히 많다. 그래서 TRT-LLM을 사용할 환경은 대부분 가이드라인을 따라하는게 좋다. 그렇지 않으면 삽질의 연속이다. 엔비디아에서 제공해주는 Triton inference 도커 컨테이너를 사용하여 환경 세팅을 진행해보았다.

Docker run

docker run -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --gpus=all --volume ${PWD}:/workspace nvcr.io/nvidia/tritonserver:24.02-trtllm-python-py3

TensorRT-LLM clone

git clone https://github.com/NVIDIA/TensorRT-LLM.git

Install requirements.txt

TensorRT-LLM/examples/llama 경로에 있는 requirements.txt를 통해 필요한 패키지를 설치하자.

pip install -r requirements.txt

2. Convert

먼저 HuggingFace에서 다운받은 DeepSeek 모델을 TRT-LLM Checkpoint 로 변환해보자. DeepSeek 모델은 기존 LLama 모델과 구조와 동일하므로 TRT-LLM 예시에 나와있는 방식을 따라하였고 필요한 파라미터를 변경했다. 여기서는 int8로 양자화를 진행했다.

python convert_checkpoint.py --model_dir {모델 경로} \
                              --output_dir {TRT checkpoint 저장 경로} \
                              --dtype float16 \
                              --use_weight_only \
                              --weight_only_precision int8

이렇게 완료하면 output_dir 경로에 config.json, rank0.safetensors두가지 파일이 생성된다. config.json 파일은 전체적인 모델 아키텍처의 메타 정보가 나와있다. rank0.safetensors에는 실제 파라미터의 웨이트 정보가 있다.

부가적으로 float32, float16모델과 int8로 양자화 모델의 사이즈를 보았는데 많이 차이가 나는 것을 확인할 수 있다.

13G	    float16
26G	    float32
6.6G	float16_to_int8

3. Build

TRT-LLM checkpoint로 만든 모델로 이제 TRT-LLM 엔진으로 빌드 해보자. gemm_plugin은 현재 모델의 dtype으로 정하였다. 자세한 사용 이유는 따로 조사가 필요해보인다.

trtllm-build --checkpoint_dir {Convert한 모델 경로} \
            --output_dir {TRT-LLM 엔진 저장 경로} \
            --gemm_plugin float16

build를 하고나면 config.json, rank0.engine 파일이 생성된다. convert 단계에서 생성된 config.json 파일은 build시에 설정한 옵션 정보가 추가된다.

build_config 일부 예시

 51   "build_config": {
 52         "max_input_len": 1024,
 53         "max_output_len": 1024,
 54         "max_batch_size": 1,
 55         "max_beam_width": 1,
 56         "max_num_tokens": 1024,
 57         "opt_num_tokens": 1,
 58         "max_prompt_embedding_table_size": 0,
 59         "gather_context_logits": false,
 60         "gather_generation_logits": false,
 61         "strongly_typed": false,
 62         "builder_opt": null,
 63         "profiling_verbosity": "layer_names_only",
 64         "enable_debug_output": false,
 65         "max_draft_len": 0,
 66         "use_refit": false,
 67         "input_timing_cache": null,
 68         "output_timing_cache": "model.cache",
 69         "lora_config": {
 70             "lora_dir": [],
 71             "lora_ckpt_source": "hf",
 72             "max_lora_rank": 64,
 73             "lora_target_modules": [],
 74             "trtllm_modules_to_hf_modules": {}
 75         },

4. Run

TRT-LLM engine을 활용하여 run을 해보자. examples/run.py 파일을 사용하여 인퍼런스를 진행해볼 수 있다.

python3 examples/run.py 
   --engine_dir={build한 엔진 경로} 
   --max_output_len 128
   --tokenizer_dir {모델 토크나이저 경로} 
   --input_text "{입력할 프롬프트}"

5. 결과

deepseek-coder 깃헙에 나와있는 가이드대로 진행하여 결과를 비교해보았다. int8, float16 결과는 완전 똑같이 나왔다.

--input_text "#write a quick sort algorithm"

int8 양자화 출력 결과

def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        less = [i for i in arr[1:] if i <= pivot]
        greater = [i for i in arr[1:] if i > pivot]
        return quick_sort(less) + [pivot] + quick_sort(greater)

#test the function
print(quick_sort([3,6,8,10,1,2,1]))
float16 출력 결과

def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        less = [i for i in arr[1:] if i <= pivot]
        greater = [i for i in arr[1:] if i > pivot]
        return quick_sort(less) + [pivot] + quick_sort(greater)

#test the function
print(quick_sort([3,6,8,10,1,2,1]))
반응형
반응형
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/02   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28
글 보관함