Documentation and formatting updates #1
@@ -12,8 +12,8 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     Convert a PyTorch model to ONNX format.
 
     Parameters:
-        output_path (str): The path to save the converted ONNX model.
-        fp16 (bool): 16 float point percision
+        ckpt_path (str): The path to save the converted ONNX model.
+        fp16 (bool): 16 float point precision
     """
     settings = get_app_settings()
 
@@ -68,8 +68,6 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
 
 if __name__ == "__main__":
 
-    settings = get_app_settings()
-
     model_checkpoint = "inference_recognition_model.ckpt"
 
     print("Converting to ONNX...")
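
For reviewers who want to see the shape of the export path touched above: convert_to_onnx presumably ends in a torch.onnx.export call. A minimal standalone sketch follows; the model object, the 1x2x1024 dummy input, the opset, and the tensor names are illustrative assumptions, not code from this repository.

    import torch

    def export_sketch(model: torch.nn.Module, onnx_path: str, fp16: bool = False) -> None:
        # Dummy input matching the model's expected shape (assumed here).
        dummy = torch.randn(1, 2, 1024)
        if fp16:
            # Half-precision export: cast both the weights and the example input.
            model, dummy = model.half(), dummy.half()
        model.eval()
        torch.onnx.export(
            model,
            dummy,
            onnx_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
            opset_version=17,
        )
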
@@ -5,8 +5,6 @@ import time
 import numpy as np
 import onnxruntime as ort
 
-from helpers.app_settings import get_app_settings
-
 
 def profile_onnx_model(
     path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
@@ -86,6 +84,5 @@ def profile_onnx_model(
 
 
 if __name__ == "__main__":
-    settings = get_app_settings()
     output_path = os.path.join("onnx_files", "inference_recognition_model.onnx")
     profile_onnx_model(output_path)
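
The body of profile_onnx_model is outside this hunk, but the removed settings import suggests the profiler is now self-contained. For orientation, a typical onnxruntime timing loop matching the signature above might look like this; the synthetic input and the handling of dynamic dimensions are assumptions, not the repository's implementation.

    import time

    import numpy as np
    import onnxruntime as ort

    def profile_sketch(path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5) -> float:
        sess = ort.InferenceSession(path_to_onnx)
        inp = sess.get_inputs()[0]
        # Replace symbolic/None dims (e.g. a dynamic batch axis) with 1 for a dummy input.
        shape = [d if isinstance(d, int) else 1 for d in inp.shape]
        x = np.random.randn(*shape).astype(np.float32)
        for _ in range(warmup_runs):
            sess.run(None, {inp.name: x})  # untimed warm-up iterations
        start = time.perf_counter()
        for _ in range(num_runs):
            sess.run(None, {inp.name: x})
        avg_ms = (time.perf_counter() - start) / num_runs * 1e3
        print(f"average latency: {avg_ms:.3f} ms over {num_runs} runs")
        return avg_ms
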
@@ -50,8 +50,6 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
     )
 
     first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
-    sample = first_rec
-    shape, dtype = sample.shape, sample.dtype
 
     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
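
With the unused sample/shape locals removed above, the write path reduces to stacking the record arrays and handing the result to h5py. A condensed sketch of that pattern, assuming each record is a (data, metadata) tuple as the comment in the hunk states:

    from typing import List, Tuple

    import h5py
    import numpy as np

    def write_hdf5_sketch(records: List[Tuple[np.ndarray, dict]], output_path: str, dataset_name: str = "data") -> None:
        # Stack the per-record arrays into one (num_records, ...) array and write it once.
        data_arr = np.stack([rec[0] for rec in records])
        with h5py.File(output_path, "w") as hf:
            hf.create_dataset(dataset_name, data=data_arr)
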
@@ -24,11 +24,9 @@ class SqueezeExcite(nn.Module):
     def __init__(
         self,
         in_chs,
-        se_ratio=0.25,
         reduced_base_chs=None,
         act_layer=nn.SiLU,
         gate_fn=torch.sigmoid,
-        divisor=1,
         **_,
     ):
         super(SqueezeExcite, self).__init__()
@@ -77,13 +75,6 @@ class GBN(torch.nn.Module):
         self.act = act
 
     def forward(self, x):
-        # chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
-        # res = [self.bn(x_) for x_ in chunks]
-        # return self.drop(self.act(torch.cat(res, dim=0)))
-        # x = self.bn(x)
-        # x = self.act(x)
-        # x = self.drop(x)
-        # return x
         return self.drop(self.act(self.bn(x)))
 
 
@@ -142,5 +142,4 @@ def plot_confusion_matrix_with_counts(
 
 if __name__ == "__main__":
     settings = get_app_settings()
-    ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
-    evaluate_checkpoint(ckpt_path)
+    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))