BERT 개념이해 #10 ONNX 변환 #
#2026-03-04
#1 ONNX 변환이 필요한 이유: 학습용 모델을 “서빙용 실행파일”로 바꾸고 싶다
학습이 끝난 BERT 모델은 보통 PyTorch 가중치 파일(.safetensors)과 파이썬 코드 형태로 존재한다. 이 상태로 서버에 올리면 두 가지 현실 문제가 바로 생긴다. 첫째, PyTorch 자체가 무겁다. 서버에 PyTorch와 그 주변 의존성을 깔아야 하고, 환경이 커지고 운영 부담이 커진다. 둘째, 추론을 할 때마다 파이썬 인터프리터 경로를 타게 된다. 결국 “모델 연산은 C++/CUDA로 하더라도, 전체 실행 흐름은 Python이 감싼다”는 구조가 남아서 오버헤드가 생긴다.
여기서 우리가 원하는 건 명확하다. 학습이 끝난 모델을 “어디서든 똑같이 실행되는 범용 형식”으로 내보내고, 그 형식을 최적화된 런타임이 빠르게 돌리게 만들고 싶다. ONNX는 그 목적을 위해 만들어진 표준 교환 형식이다.
#
#2 ONNX의 정체: 모델을 파이썬 객체가 아니라 “계산 그래프”로 저장한다
PyTorch 모델은 사람 입장에서는 편하다. 클래스가 있고, forward가 있고, 조건문도 쓸 수 있고, 코드로 표현이 된다. 하지만 서버 입장에서는 불편하다. 서버는 “코드를 실행하는 것”보다 “이미 정해진 계산 그래프를 최대한 빠르게 돌리는 것”이 더 잘 맞는다.
ONNX는 모델을 코드가 아니라 그래프로 만든다. 그래프에는 노드들이 있고, 노드는 MatMul, Add, LayerNorm, Softmax 같은 기본 연산들이다. BERT를 ONNX로 내보내면 “Transformer 12층”이 파이썬 클래스 형태로 저장되는 게 아니라, 수백 개의 연산 노드로 풀어헤쳐진 그래프가 파일에 직렬화된다. 그래서 ONNX 파일은 “이 모델은 이렇게 계산한다”라는 실행 계획서에 가깝다. 이 계획서를 onnxruntime이 읽으면, 파이썬 없이도 같은 계산을 수행할 수 있다.
PyTorch 모델 (Python 객체):
BertForSequenceClassification
└─ BertModel (12 Transformer Layers)
└─ EnhancedClassifier (768→256→4)
↓ torch.onnx.export()
ONNX 그래프 (bert_enhanced.onnx):
node: MatMul ← Linear(768, 768)
node: Add ← + bias
node: LayerNorm
node: MatMul ← Q/K/V projection
... (수백 개의 기본 연산 노드)
node: Softmax ← 최종 분류
저장: data/models_enhanced/bert_enhanced.onnx
크기: ~418 MB
torch.onnx.export()는 모델을 단순히 저장하는 게 아니라, “이 모델이 실제로 어떤 연산을 수행하는지”를 알아내기 위해 한 번 실행해본다. 그래서 더미 입력이 필요하다. 더미 입력을 모델에 넣고 forward가 어떤 연산들을 호출하는지 추적해서, 그 경로를 ONNX 그래프로 기록한다.
여기서 중요한 감각은 이거다. ONNX 변환은 “가중치만 떼어내는 작업”이 아니라 “가중치 + 연산 흐름”을 통째로 굳혀서 그래프로 만드는 작업이다. 그래서 변환이 끝나면 .onnx 파일 하나만으로도 모델의 구조와 파라미터가 함께 들어간 형태가 된다.
[서빙 시]
bert_enhanced.onnx
↓
onnxruntime.InferenceSession()
→ 그래프 최적화 (연산 융합, 상수 접기)
→ CPU/GPU 커널 선택
↓
sess.run(None, {"input_ids": ..., "attention_mask": ...})
↓
logits (numpy array)
더미 입력을 만들 때 흔히 padding=“max_length”, max_length=192로 만들면 입력 shape이 (1, 192)가 된다. 문제는 ONNX가 이 shape을 “이 모델 입력은 항상 (1, 192)다”라고 고정해버릴 수 있다는 것이다. 그러면 실제 서빙에서 배치 크기가 32가 되거나, dynamic padding으로 길이가 64가 되는 순간 모델이 shape mismatch로 터진다.
그래서 dynamic_axes를 지정해줘야 한다. “0번 축은 배치 크기라서 늘어날 수 있다”, “1번 축은 시퀀스 길이라서 달라질 수 있다”를 명시적으로 선언하는 것이다. 이 선언이 들어가면 ONNX 그래프는 입력 텐서의 0번과 1번 축을 고정 숫자가 아니라 ‘심볼릭 차원’으로 취급하게 되고, 어떤 배치/길이가 들어와도 실행 가능해진다. 즉 dynamic_axes는 “학습 때의 고정 텐서”를 “서빙에 필요한 가변 텐서”로 바꿔주는 안전장치다.
ONNX 변환 시 더미 입력을 받는다:
dummy: tokenizer("Sample text", max_length=192)
→ input_ids.shape = (1, 192) ← 배치 1, 시퀀스 192
dynamic_axes 없이 변환하면:
모든 입력이 (1, 192)로 고정됨
→ 배치 32, 시퀀스 64로 실행하면 에러!
dynamic_axes 지정하면:
"input_ids": {0: "batch_size", 1: "sequence"}
→ 배치 차원(0)과 시퀀스 차원(1) 모두 가변
→ 어떤 배치 크기, 어떤 시퀀스 길이도 OK
서버에서 onnxruntime에 입력을 줄 때는 딕셔너리 형태로 준다. “input_ids”: …, “attention_mask”: …처럼 이름으로 매칭한다. 그런데 export할 때 input_names를 안 주면, PyTorch가 ONNX 내부 입력 이름을 자동으로 만들어버리는 경우가 있다. 그러면 서버는 “input_ids"를 넣었는데 그래프에는 “onnx::Gather_0” 같은 이름만 있어서 “그 키가 없다”는 에러가 난다.
또 한 가지 자주 걸리는 함정이 HuggingFace BERT의 token_type_ids다. BERT는 원래 문장쌍을 구분하기 위해 token_type_ids를 받을 수 있다. 그런데 더미 입력에 token_type_ids를 포함하지 않았는데 export 과정에서 모델이 기대하는 입력이 달라지거나 이름이 꼬이면, 추론 시 입력 매칭이 어긋날 수 있다. 그래서 가장 안전한 방식은 지금 너가 적어둔 것처럼, export에서 input_names=[“input_ids”, “attention_mask”]를 명확히 지정하고, 더미 입력 튜플의 순서도 그와 1:1로 맞추는 것이다. 그러면 “서버가 넣는 키”와 “그래프가 기대하는 키”가 확실히 고정된다.
ONNX는 “이런 연산 노드들이 존재한다”는 표준 집합을 버전으로 관리한다. 그게 opset_version이다. 버전이 너무 낮으면 BERT에서 필요한 연산(특히 LayerNorm을 구성하는 패턴이나 특정 연산 조합)이 제대로 표현되지 않거나, 런타임이 지원하지 않아 변환/실행이 깨질 수 있다. 반대로 버전이 너무 최신이면 실행 환경의 런타임이 아직 그 opset을 완벽히 지원하지 못할 수도 있다. 그래서 보통 “충분히 최신이면서 호환이 좋은 버전”을 선택한다. 너의 코드에서 opset 14를 쓴 이유도 그 현실적인 균형점 때문이다.
ONNX 파일이 생성됐다고 해서 끝이 아니다. 그래프 구조가 깨졌거나, 타입이 맞지 않거나, 노드 연결이 이상해도 파일은 생길 수 있다. 그래서 onnx.checker.check_model()로 그래프의 기본 정합성을 검사한다. 이 단계는 “내보내기 성공”이 아니라 “내보낸 그래프가 ONNX 규격상 올바르다”를 확인하는 단계다. 실제 서빙에서 터지기 전에 미리 잡는 안전벨트라고 보면 된다.
#
#3 서빙에서의 실행 흐름: 세션을 한 번 만들고, 그 세션에 numpy 입력만 던진다
PyTorch는 보통 요청마다 모델 forward를 호출하는 구조가 되기 쉬운데, ONNX Runtime은 다르게 쓴다. 서버가 시작할 때 InferenceSession을 한 번 만들어서 그래프를 로딩한다. 이때 런타임은 내부적으로 최적화 패스를 돌린다. 연산을 합칠 수 있으면 합치고, 상수를 접을 수 있으면 접고, 메모리 배치를 더 효율적으로 바꾼다. 그 다음부터 요청이 들어오면 sess.run()에 입력을 넣고 결과를 받기만 하면 된다.
여기서 또 하나의 실용 포인트가 있다. onnxruntime은 입력을 보통 numpy로 받는다. 그래서 토크나이저가 만든 PyTorch 텐서를 .numpy()로 바꿔서 넣는다. 그러면 출력으로 logits가 numpy array로 나오고, 그 다음 후처리(softmax, argmax)를 해서 라벨과 confidence를 만든다. 즉 서버에서 “파이썬 모델 실행”이 아니라 “최적화된 그래프 실행”만 남는다.
#
#4 ONNX가 빠른 이유: 런타임이 그래프 전체를 보고 ‘합칠 건 합친다’
PyTorch는 eager 실행 모델이라, 연산이 파이썬 레벨에서 순서대로 호출되는 느낌이 강하다. 반면 ONNX Runtime은 그래프 전체를 한 번에 보고 최적화를 할 수 있다. 예를 들어 LayerNorm은 수학적으로는 평균 계산, 분산 계산, 정규화, 스케일/시프트 같은 여러 연산의 조합인데, 런타임은 이 패턴을 알아보고 “LayerNorm fused kernel” 하나로 합쳐서 실행할 수 있다. 연산을 합치면 커널 호출 횟수가 줄고, 메모리 접근이 줄고, CPU 캐시 효율도 올라간다. 그래서 같은 모델이라도 ONNX가 더 빠르게 나오는 경우가 많다. 너의 벤치마크에서 CPU에서 1.5~3배 향상이 나온 이유는 이런 그래프 레벨 최적화의 결과라고 보면 된다.
PyTorch 실행 경로:
Python 코드 → PyTorch C++ 연산 → 각 연산 개별 실행
ONNX Runtime 실행 경로:
그래프 로딩 → 최적화 패스 실행:
① 연산 융합 (LayerNorm = Element-wise ops 3개 → 1개 커널)
② 상수 접기 (bias 등 고정값 미리 계산)
③ 메모리 레이아웃 최적화
→ 최적화된 커널로 실행
특히:
LayerNorm: PyTorch는 μ 계산, (x-μ) 계산, σ² 계산, 정규화를 별도 연산
ONNX Runtime은 fused LayerNorm 커널로 1번에 처리
실무에서는 모델이 새로 학습될 때마다 서빙 포맷도 같이 갱신되어야 한다. 그래서 retrain 파이프라인에서 champion이 승격될 때만 export_onnx를 돌리고, 그 결과물을 MLflow에 아티팩트로 저장하는 구조는 굉장히 합리적이다. 이 흐름은 한마디로 “훈련 산출물(가중치)과 운영 산출물(ONNX)을 동시에 관리한다”는 뜻이다. 이렇게 해두면 운영 서버는 항상 ONNX 파일만 가져가서 실행하면 되고, PyTorch 환경을 서빙 서버에 유지할 필요가 줄어든다.
#
#cf 프로덕션 확장: ONNX + 하드웨어 가속
환경별 권장 구성:
[x86 CPU 서버]
onnxruntime (CPUExecutionProvider)
→ 현재 구성, P50≈25ms
[NVIDIA GPU]
onnxruntime-gpu (CUDAExecutionProvider)
→ CUDA 최적화 커널, P50≈5ms 수준
[Apple Silicon]
onnxruntime (CoreMLExecutionProvider)
→ Neural Engine 활용, P50≈3ms 수준
[엣지 배포]
ONNX → TensorFlow Lite 또는 OpenVINO 변환
→ 모바일/임베디드 환경
#
#5 정리: ONNX 변환은 ‘학습 모델’을 ‘운영 가능한 실행 그래프’로 바꾸는 과정이다
ONNX는 PyTorch 모델을 범용 계산 그래프로 직렬화해, 파이썬과 PyTorch 없이도 최적화된 런타임에서 실행할 수 있게 만든다. export는 더미 입력으로 연산 경로를 추적해 그래프를 만들고, dynamic_axes로 배치/길이를 가변으로 풀어 서빙 환경에서 깨지지 않게 한다. input_names를 명시해 입력 매칭 문제를 막고, opset_version으로 필요한 연산 지원을 확보한다. 변환된 ONNX는 onnxruntime이 로딩 시 최적화하고, 추론 시에는 numpy 입력을 받아 빠르게 logits를 내놓는다. 결국 ONNX 변환은 “모델을 더 빠르고, 더 가볍고, 더 이식성 있게” 만드는 프로덕션용 마무리 단계라고 보면 된다.
- ONNX = PyTorch 계산 그래프를 범용 형식으로 직렬화
torch.onnx.export(): 더미 입력으로 그래프를 추적(trace)하여.onnx파일 생성dynamic_axes: 배치 크기 / 시퀀스 길이를 가변으로 허용 (필수!)input_names명시: token_type_ids 이름 불일치 버그 방지opset_version=14: BERT 연산(LayerNorm 등) 지원하는 최소 버전onnxruntime: ONNX 파일을 최적화하여 실행하는 런타임 (PyTorch 불필요)- 속도 향상: CPU에서 PyTorch 대비 1.5~3x 빠름 (연산 융합, 커널 최적화)
- 이 프로젝트:
bert_enhanced.onnx→ FastAPI/predict엔드포인트에서 기본 추론 엔진