<aside> 💡 Onnx로 Model Conversion 진행 계기
</aside>
# Onnx Runtime(ORT)설치
## CPU version
pip install onnxruntime
## GPU version
pip install onnxruntime-gpu
# onnx를 onnx runtime으로 실행시키는 코드
class Inference:
def __init__(self) -> None:
self.storage_client = storage.Client()
bucket = self.storage_client.bucket(BUCKET_NAME)
newest_blob=None
for blob in bucket.list_blobs():
if newest_blob==None:
newest_blob=blob
else:
if newest_blob.time_created<blob.time_created:
newest_blob=blob
contents = blob.download_as_string()
self.logger.info(f"Downloaded ONNX from {BUCKET_NAME} as {newest_blob.name}")
self.session=self.__create_session(contents)
self.blob_name=newest_blob.name
def __create_session(self,model: str) -> ort.InferenceSession:
return ort.InferenceSession(model)
def run(self,x):
out=self.session.run(None, {'input': x})
self.logger.info(out[0][0].tolist())
return out[0][0]
원인:
0에 가까운 가중치를 가질 때 ONNX에서 inference 시간이 20~30배 정도 늘어나는 문제가 있었다. 이는 지수부 제한으로 인해 정규화되지 못한 비정규값들 때문에 cpu연산 속도가 느려지는 일이었다.
해결방법:
torch.set_flush_denormal(True)
를 실행해 주면 이런 0과 가까운 가중치들을 다 0으로 만들어서 원래의 속도를 다시 재현할 수 있었다
# Examples
torch.save(model.state_dict(), f"{save_dir}/{data}_{config.model}_best_epoch{epoch}_{macro_f1_score:6.4}.pth")
torch.set_flush_denormal(True)
torch.onnx.export(model, dummy_input, f"{save_dir}/{data}_{config.model}_best_{macro_f1_score:6.4}.onnx", export_params=True, opset_version=11,
input_names = ['input'],
output_names = ['output'],
dynamic_axes={'input' : {0 : 'batch_size'},
'output' : {0 : 'batch_size'}})
best_macro_f1_score = macro_f1_score
torch.set_flush_denormal(False)
Reference
원인:
Timm 라이브러리에서 가져온 Pretrained efficientnet에서 activation funcion이 SiLU로 설정되어 있는데 Onnx에서 이를 지원하지 않음
해결방법:
모델의 activaton function을 Relu로 변경해서 Onnx Conversion이 가능하도록 변경
# Examples
class EfficientNetB0(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.efficientnet = timm.create_model('efficientnet_b0', pretrained = True, num_classes = num_classes, drop_rate=0.5, act_layer = nn.ReLU)
def forward(self, x):
x = self.efficientnet(x)
return x