들어가며
Python 코드를 읽다 보면 가끔 이런 decorator를 만난다.
from numba import njit
@njit
def some_fast_function(x):
...
처음 보면 그냥 “빠르게 해주는 마법”처럼 보인다. 그런데 실제로는 마법이라기보다, 꽤 명확한 trade-off가 있는 도구이다.
이 글에서는 Numba JIT를 다음 관점으로 정리해보려고 한다.
- Numba가 무엇인가.
- JIT compilation이 무엇인가.
@jit,@njit가 함수에 무슨 일을 하는가.- 왜 어떤 코드는 빨라지고 어떤 코드는 별로 안 빨라지는가.
- UMAP, SHAP, librosa 같은 유명 Python project에서는 어디에 Numba를 쓰는가.
Numba란 무엇인가?
Numba는 Python 함수를 빠른 machine code로 바꿔주는 JIT compiler이다. (이름은 NumPy + Mamba에서 왔다. 맘바(mamba)는 빠른 뱀이고, 파이썬(python)은 프로그래밍 언어 이름이면서 비단뱀이라는 뜻도 있으니, “Python code를 맘바처럼 빠르게 만든다”는 말장난이다.)
특히 NumPy array, scalar, numerical loop처럼 type과 memory layout이 비교적 분명한 코드에서 잘 동작한다.
조금 쉽게 말하면, Numba는 Python 전체 실행 환경을 바꿔서 모든 코드를 빠르게 만드는 도구가 아니다. 대신 내가 지정한 함수 하나를 보고, “이 함수는 숫자 연산 위주니까 Python interpreter를 매번 거치지 않아도 되겠네”라고 판단되면, 그 함수를 CPU가 바로 실행할 수 있는 code로 compile해준다.
그래서 Numba를 쓸 때는 보통 이런 식으로 함수 위에 decorator를 붙인다.
from numba import njit
@njit
def sum_square(x):
acc = 0.0
for i in range(x.shape[0]):
acc += x[i] * x[i]
return acc
공식 문서에서는 Numba를 Python용 Just-In-Time compiler라고 설명하고, NumPy array, NumPy function, loop가 있는 코드에 가장 잘 맞는다고 설명한다. Anaconda 쪽 설명에서는 “NumPy-aware optimizing compiler”라는 표현도 쓴다.
그러니 이름은 이렇게 기억하는 정도면 충분하다.
Numba = NumPy/numerical Python code를 빠르게 compile해주는 JIT compiler
Just-In-Time (JIT) Compilation이란?
JIT는 Just-In-Time compilation의 약자이다. 말 그대로 “실행 직전에 compile한다”는 뜻이다.
보통 Python 함수는 Python interpreter가 한 줄씩 해석하면서 실행한다. 이 방식은 유연하지만, loop가 많고 숫자 연산이 촘촘한 코드에서는 overhead가 크다.
Numba는 특정 Python 함수를 가져와서, 그 함수가 실제로 호출될 때 입력 type을 보고 machine code로 compile한다.
대략 흐름은 이렇다.
Python function
-> @njit decorator
-> 첫 호출에서 argument type 확인
-> Numba compiler가 machine code 생성
-> 이후 같은 type으로 호출하면 compiled code 실행
그래서 첫 호출은 느릴 수 있다. compile 비용을 내기 때문이다. 하지만 그 다음 호출부터는 Python interpreter를 덜 거치고 compiled code를 실행하므로 빨라질 수 있다.
Numba의 핵심 역할
Numba의 역할은 다음처럼 요약할 수 있다.
NumPy array와 scalar를 다루는 Python numeric function을 LLVM 기반 machine code로 JIT compile해주는 도구.
여기서 LLVM은 원래 Low Level Virtual Machine이라는 이름에서 출발했지만,
현재 공식적으로는 acronym이라기보다 LLVM compiler infrastructure project의 이름으로 쓰인다.
여기서 중요한 단어는 numeric function이다. Numba는 모든 Python을 빠르게 해주는 도구가 아니다. 숫자 계산, array loop, 수치 알고리즘처럼 type이 어느 정도 고정되고 반복 연산이 많은 코드에서 힘을 낸다.
예를 들어 이런 코드는 Numba가 좋아한다.
import numpy as np
from numba import njit
@njit
def sum_square(x):
acc = 0.0
for i in range(x.shape[0]):
acc += x[i] * x[i]
return acc
이 함수는 입력이 float64[:] 같은 NumPy array라고 추론되면,
loop 전체를 compiled code로 바꿀 수 있다.
반대로 이런 코드는 Numba와 잘 맞지 않는다.
@njit
def not_good(xs):
out = []
for x in xs:
out.append({"value": x})
return out
Python object, dictionary, dynamic list, string 처리, file I/O, pandas object 같은 것이 많이 섞이면 Numba가 compile하기 어렵다.
입력/출력 파이프라인
먼저 Numba compiler가 Python function과 실제 argument type을 입력으로 받아 dispatcher, typed IR, machine code를 만드는 흐름을 살펴보자. 처음 읽을 때 헷갈리는 질문은 “Numba가 Python 함수를 C로 바꾸는가?”이다.
정확히는 C source code를 만들어서 저장하는 것은 아니다. Numba는 Python function을 분석하고, type을 추론한 뒤, LLVM을 거쳐 machine code를 만든다.
여기서 IR, typed IR, LLVM 같은 단어가 갑자기 나오면 조금 어렵다.
간단히만 잡고 가자.
- IR: Intermediate Representation의 약자이다. source code와 machine code 사이에 있는 compiler 내부용 중간 표현이다.
- typed IR: 각 변수와 연산의 type 정보가 붙은 IR이다. 예를 들어
x[i]가 그냥 Python object인지,float64array의 원소인지가 정해진 상태라고 보면 된다. - LLVM: 여러 compiler가 공통 backend처럼 사용할 수 있는 compiler infrastructure이다. 여기서는 최적화와 machine code 생성을 담당한다고 이해하면 된다.
- lowering: 더 높은 수준의 표현을 더 낮은 수준의 표현으로 바꾸는 과정이다. 여기서는 Numba가 만든 IR을 LLVM이 이해할 수 있는 형태와 machine code에 가까운 형태로 낮추는 단계라고 보면 된다.
조금 더 쉬운 흐름으로 쓰면 다음과 같다.
Python code
-> Numba IR
-> typed IR
-> LLVM IR
-> machine code
즉 Numba는 Python source를 바로 CPU 명령어로 바꾸는 것이 아니라, 중간 표현을 여러 단계 거치면서 type을 확정하고, LLVM이 잘 최적화할 수 있는 형태로 낮춘 뒤, 최종적으로 실행 가능한 machine code를 만든다.
모듈별로 보면 다음과 같다.
| 단계 | 입력 | 출력 | 역할 |
|---|---|---|---|
| Python function | def f(x): ... |
Python callable | 원래 코드 |
@jit / @njit |
Python function | dispatcher | 호출 type별 compiled version 관리 |
| type inference | 실제 argument type | typed IR | Python object 없이 돌 수 있는지 판단 |
| LLVM lowering | typed IR | machine code | CPU에서 직접 실행할 code 생성 |
| runtime call | NumPy array / scalar | return value | compiled function 실행 |
사용자 입장에서 보이는 것은 그냥 함수 호출이다.
y = sum_square(x)
하지만 내부적으로는 첫 호출에서 compile이 일어나고,
그 뒤 같은 type의 x가 들어오면 이미 만든 compiled version을 재사용한다.
@jit와 @njit
Numba에서 가장 많이 보는 decorator는 두 개이다.
from numba import jit, njit
요즘 코드를 읽을 때는 보통 @njit를 먼저 이해하면 된다.
njit는 nopython=True라는 뜻으로 보면 된다.
@njit
def f(x):
...
이 말은 “이 함수 안에서는 Python interpreter 도움 없이 compile 가능한 code만 쓰겠다”에 가깝다.
nopython이라는 이름이 붙은 이유는 Python 문법으로 코드를 쓰지만,
실행할 때는 Python object를 최대한 끊고 native type으로 돌리기 때문이다.
중요한 차이는 이것이다.
- nopython mode: Python interpreter 없이 compiled code로 실행된다. 빠르다.
- object mode: Python object를 많이 유지한다. fallback 성격이 강하고 성능 이점이 작을 수 있다.
최근 Numba 문서에서는 @jit의 기본도 nopython mode 쪽으로 설명한다.
그래도 코드 읽기에서는 @njit가 더 의도가 분명하다.
@njit가 보이면 “이 함수는 Python object 없이 숫자 loop를 compiled code로 돌리려는 hot path구나”라고 읽으면 된다.
Numba가 빨라지는 이유
Python loop가 느린 이유 중 하나는 매 iteration마다 Python interpreter overhead가 들어가기 때문이다.
예를 들어 다음 코드를 생각해보자.
for i in range(n):
y[i] = a * x[i] + b
사람 눈에는 단순한 곱셈과 덧셈이다. 하지만 Python interpreter 입장에서는 매번 다음을 확인해야 한다.
i는 어떤 object인가.x[i]indexing은 어떤 method를 호출하는가.a * x[i]는 어떤 type의 곱셈인가.y[i] = ...는 어떤 set operation인가.
Numba는 compile 시점에 type을 고정할 수 있으면 이 overhead를 줄인다.
x가 float64 array이고 a, b가 float이라는 것을 알면,
loop를 CPU가 바로 실행할 수 있는 code로 낮출 수 있다.
그래서 Numba는 이런 경우에 특히 잘 맞는다.
| 코드 형태 | Numba와의 궁합 |
|---|---|
| NumPy array를 돌면서 scalar 연산 | 좋음 |
| 작은 function이 엄청 많이 호출됨 | 좋음 |
| 조건문이 있는 numerical loop | 좋음 |
| Python list/dict/object 중심 코드 | 나쁨 |
| pandas DataFrame method chaining | 나쁨 |
| 이미 큰 NumPy/SciPy vectorized op 한두 번만 호출 | 애매함 |
| GPU tensor library, 예를 들어 PyTorch op 조합 | 보통 Numba 대상이 아님 |
첫 호출이 느린 이유
Numba를 처음 써보면 이런 상황을 볼 수 있다.
sum_square(x) # 첫 호출은 느림
sum_square(x) # 두 번째부터 빠름
이건 정상이다. 첫 호출에서는 Numba가 argument type을 보고 compile한다. 그 뒤에는 같은 type signature에 대해 compiled version을 재사용한다.
그래서 benchmark할 때는 첫 호출을 빼고 재야 한다.
sum_square(x) # warm-up, compile cost 포함
%timeit sum_square(x) # compiled code 실행 시간 측정
cache=True를 쓰면 compiled result를 disk cache에 저장해서 다음 process에서 compile 비용을 줄일 수 있다.
다만 cache는 file 위치, 환경, Python/Numba version에 민감할 수 있다.
parallel=True와 prange
Numba는 단순히 single-thread loop만 빠르게 하는 것이 아니다. 조건이 맞으면 loop를 multi-thread로 parallelize할 수도 있다.
from numba import njit, prange
@njit(parallel=True)
def row_sum(X):
out = np.zeros(X.shape[0])
for i in prange(X.shape[0]):
s = 0.0
for j in range(X.shape[1]):
s += X[i, j]
out[i] = s
return out
여기서 prange는 parallel range이다.
각 row 계산이 서로 독립이라면 thread로 나눠 실행할 수 있다.
하지만 아무 loop나 parallelize되는 것은 아니다. iteration 사이 dependency가 있으면 조심해야 한다.
실제로 어느 정도 빨라지는지 궁금해서 repo 밖의 /private/tmp/numba_prange_benchmark.py에서 간단히 재봤다.
첫 호출은 compile warm-up으로 버리고, 같은 float64 matrix에 대해 serial range version과 parallel prange version의 median runtime을 비교했다.
환경은 `Python 3.12.13, NumPy 2.0.2, Numba 0.65.1, Numba thread 12개이다.
| shape | data | serial median | parallel median | speedup |
|---|---|---|---|---|
| 512 x 256 | 1.0 MiB | 0.049 ms | 0.038 ms | 1.29x |
| 2,048 x 256 | 4.0 MiB | 0.193 ms | 0.086 ms | 2.23x |
| 8,192 x 256 | 16.0 MiB | 0.777 ms | 0.242 ms | 3.21x |
| 32,768 x 256 | 64.0 MiB | 3.179 ms | 0.825 ms | 3.85x |
| 131,072 x 256 | 256.0 MiB | 12.768 ms | 2.745 ms | 4.65x |
이 결과를 보면 작은 input에서는 thread를 나누는 overhead 때문에 이득이 작고, row 수가 커질수록 parallel execution의 이득이 커진다. 다만 12 threads라고 해서 12배 빨라지는 것은 아니다. 이 예제는 결국 memory에서 값을 읽고 더하는 작업이라 memory bandwidth의 영향을 많이 받기 때문이다.
실제 유명 코드 예시
이제 실제 project에서는 Numba를 어디에 쓰는지 보자. 공통점은 모두 “전체 application을 Numba로 compile”하지 않는다는 점이다. 대신 병목이 되는 작은 numerical kernel에만 붙인다.
1. UMAP: neighbor graph 계산의 hot loop
UMAP은 dimensionality reduction에서 많이 쓰이는 library이다. UMAP은 nearest neighbor graph를 만들고, fuzzy simplicial set이라는 graph weight를 계산한다.
umap/umap_.py를 보면 이런 함수들이 Numba로 compile된다.
smooth_knn_distcompute_membership_strengthsfast_intersectionfast_metric_intersection
특히 smooth_knn_dist는 @numba.njit(..., parallel=True)로 되어 있다.
각 sample마다 nearest neighbor distance를 보고 적절한 local scale을 찾는 함수이다.
이걸 Python으로만 돌리면 sample 수와 neighbor 수에 따라 loop가 커진다. 그래서 UMAP은 이 부분을 Numba로 빼서 빠르게 만든다.
mental model은 이렇다.
knn distances
-> 각 row마다 binary search
-> local scale sigma/rho 계산
-> graph membership strength 계산
-> sparse graph 생성
이건 Numba가 좋아하는 형태이다.
- 입력이 NumPy array이다.
- loop가 많다.
- row별 계산이 비교적 독립적이다.
- Python object를 많이 만들 필요가 없다.
그래서 UMAP 코드에서 @numba.njit(parallel=True)를 보면,
“graph construction 단계의 numerical loop를 compiled kernel로 만든 것”이라고 읽으면 된다.
2. SHAP: mask와 clustering 관련 반복 연산
SHAP은 model explanation에서 많이 쓰이는 library이다. SHAP은 feature subset mask를 만들고, mask 순서를 바꾸고, clustering tree를 따라가며 index를 채우는 작업이 많다.
shap/utils/_clustering.py에는 from numba import njit가 있고,
다음 같은 helper들이 @njit로 compile된다.
_pt_shuffle_recdelta_minimization_order_reverse_window_mask_delta_score
shap/utils/_masked_model.py에도 _build_fixed_single_output, _build_fixed_multi_output, _init_masks, _rec_fill_masks 같은 함수들이 @njit로 compile된다.
이 함수들은 딥러닝 model 자체를 compile하는 것이 아니다. 대신 explanation 과정에서 반복적으로 등장하는 mask manipulation, index fill, output aggregation 같은 작은 병목을 줄인다.
mental model은 이렇다.
mask / index array
-> 반복적으로 swap, reverse, fill
-> model output을 batch 위치에 맞게 누적
-> Python loop overhead를 줄이기 위해 njit 적용
여기서도 Numba를 쓰는 이유는 분명하다. mask는 결국 boolean/integer array이고, 작은 loop가 많이 돈다. 이런 코드는 Python으로 쓰면 읽기 쉽지만 느릴 수 있고, Numba를 붙이면 C extension을 직접 쓰지 않고도 속도를 끌어올릴 수 있다.
3. librosa: audio utility의 dense array kernel
librosa는 audio/music analysis에서 많이 쓰이는 library이다.
librosa 문서의 librosa.util.utils source를 보면 __shear_dense라는 dense array helper가 @numba.jit(nopython=True, cache=True)로 compile된다.
이 함수는 dense matrix의 column들을 일정한 규칙으로 roll해서, lag representation과 recurrence representation 사이를 바꾸는 데 쓰인다.
대략 구조는 이렇다.
dense matrix X
-> column index i마다
-> X[:, i]를 factor * i만큼 roll
-> sheared matrix 반환
이것도 Numba가 좋아하는 형태이다.
- 입력은 dense NumPy array이다.
- column loop가 있다.
- 각 column에 대해 단순한 array operation을 반복한다.
cache=True로 compile result 재사용을 의도한다.
librosa 전체가 Numba로 되어 있는 것은 아니다. 오디오 I/O, high-level feature API, validation logic은 여전히 Python/NumPy/SciPy 중심이다. Numba는 그중 반복적인 dense array utility에 선택적으로 붙어 있다.
C++ 대신 Numba를 쓰는 이유
Numba의 장점은 Python 코드 형태를 크게 유지하면서 hot loop만 빠르게 만들 수 있다는 점이다.
C++ extension이나 Cython을 쓰면 더 세밀한 제어가 가능하지만, 빌드 시스템, compiler, wheel 배포, platform compatibility 부담이 생긴다.
Numba는 그 중간 지점에 있다.
pure Python
-> 가장 쉽지만 loop가 느릴 수 있음
NumPy vectorization
-> 빠르지만 복잡한 control flow를 표현하기 어려울 수 있음
Numba
-> Python loop를 유지하면서 numerical hot path를 compile
C++ / Cython / Rust extension
-> 더 강력하지만 개발과 배포 부담이 큼
그래서 research code나 scientific Python library에서 Numba를 자주 볼 수 있다. 알고리즘을 Python으로 읽을 수 있게 유지하면서, 진짜 느린 loop만 compiled code로 바꾸기 좋기 때문이다.
Numba를 읽을 때 보는 순서
코드에서 @njit를 만나면 나는 보통 다음 순서로 본다.
- 입력 type이 NumPy array인지 본다.
- 함수 안에 큰 loop가 있는지 본다.
- Python object, list, dict, class instance가 많은지 본다.
- 첫 호출 compile cost가 문제가 될 위치인지 본다.
parallel=True,prange,cache=True,fastmath=True같은 option을 확인한다.- 이 함수가 전체 pipeline에서 hot path인지 확인한다.
특히 중요한 질문은 이것이다.
이 함수는 왜 NumPy vectorization만으로 충분하지 않았을까?
대개 답은 둘 중 하나이다.
- loop 안에 condition이나 custom logic이 많다.
- intermediate array를 크게 만들지 않고 scalar loop로 처리하고 싶다.
자주 하는 오해
1. Numba를 붙이면 모든 Python이 빨라진다
아니다. Numba는 Python의 일부 numerical subset에 강하다. 일반 web code, file parsing, pandas-heavy pipeline, class object 중심 code에는 보통 맞지 않는다.
2. NumPy code에는 Numba가 항상 필요하다
아니다.
이미 np.dot, scipy.linalg, np.fft처럼 내부가 최적화된 native library를 한 번 호출하는 코드라면 Numba를 붙여도 큰 이득이 없을 수 있다.
3. 첫 benchmark가 느리면 Numba가 느린 것이다
꼭 그렇지는 않다. 첫 호출에는 compile 비용이 들어간다. warm-up 후 실행 시간을 따로 봐야 한다.
4. Numba는 PyTorch나 JAX와 같은 역할이다
아니다. Numba는 일반 Python numerical function을 compile하는 도구에 가깝다. PyTorch/JAX는 tensor computation graph, autograd, accelerator backend까지 포함하는 더 큰 framework이다.
Numba를 쓰기 좋은 경우
Numba는 다음 상황에서 먼저 떠올려볼 만하다.
- Python loop가 profiler에서 병목으로 잡힌다.
- loop 안 연산이 숫자/NumPy scalar 중심이다.
- vectorization으로 바꾸면 코드가 너무 복잡해진다.
- C++ extension까지 만들기에는 부담스럽다.
- 같은 function을 같은 type으로 여러 번 호출한다.
반대로 다음 상황이면 조심하는 게 좋다.
- 함수가 한 번만 호출된다.
- 데이터가 작아서 compile overhead가 더 크다.
- pandas object나 Python dict/list가 핵심이다.
- 이미 BLAS/LAPACK/FFT 같은 native library call이 대부분이다.
- 배포 환경에서 Numba/llvmlite version 관리가 부담스럽다.
정리
Numba JIT는 Python을 통째로 빠르게 만드는 버튼이 아니다. 더 정확히는,
NumPy array와 scalar 중심의 numerical hot loop를 Python 문법으로 유지한 채, 실행 시점에 machine code로 바꿔주는 도구
라고 보는 게 좋다.
그래서 유명한 project들도 Numba를 application 전체에 바르지 않는다. UMAP은 graph construction의 loop에, SHAP은 mask/index manipulation에, librosa는 dense array utility에 선택적으로 쓴다.
코드에서 @njit를 만나면 이렇게 읽으면 된다.
여기는 Python으로 쓰면 읽기 쉽지만,
반복 횟수가 많아서 interpreter overhead가 병목이 되는 numerical kernel이구나.
References
- Numba documentation - Compiling Python code with @jit
- Numba documentation - A 5 minute guide to Numba
- Numba documentation - FAQ: Where does the project name Numba come from?
- numba/numba GitHub README
- Anaconda.org - numba
- LLVM Compiler Infrastructure Project
- UMAP source -
umap/umap_.py - SHAP source -
shap/utils/_clustering.py - SHAP source -
shap/utils/_masked_model.py - librosa source docs -
librosa.util.utils