| | import os |
| | import sys |
| | if __name__ == "__main__": |
| | from train import * |
| | else: |
| | from .train import * |
| |
|
| |
|
| |
|
| |
|
def collect_checkpoint_paths(path):
    """Return the checkpoint files designated by *path*.

    Args:
        path: Either a directory containing checkpoint files or the path of
            a single checkpoint file.

    Returns:
        A list of file paths. For a directory, every regular file directly
        inside it, joined with the directory and sorted by path so the
        evaluation order is reproducible. For a file, a one-element list
        holding *path* unchanged. For anything else (missing path, special
        file), an empty list.
    """
    if os.path.isdir(path):
        # os.listdir returns entries in arbitrary order; sort for determinism.
        candidates = sorted(
            os.path.join(path, entry) for entry in os.listdir(path)
        )
        # Sub-directories cannot be loaded as checkpoints, so skip them here
        # instead of failing later when the file is deserialized.
        return [candidate for candidate in candidates if os.path.isfile(candidate)]
    if os.path.isfile(path):
        return [path]
    return []


# The checkpoint location is the first command-line argument. The built-in
# default is only allowed when this file is executed directly as a script.
try:
    test_item = sys.argv[1]
except IndexError:
    # Explicit raise rather than `assert`, so the guard still fires under
    # `python -O`; the exception type is unchanged.
    if __name__ != "__main__":
        raise AssertionError(
            "no checkpoint path was given on the command line and this "
            "module is not being run as a script"
        )
    test_item = "./checkpoint_test"

test_items = collect_checkpoint_paths(test_item)
if not test_items:
    # Previously a wrong path silently resulted in nothing being evaluated.
    print(f"warning: no checkpoint files found at {test_item!r}", file=sys.stderr)
| |
|
| |
|
| |
|
| |
|
# Evaluate each collected checkpoint in turn.
# `torch`, `model`, `device` and `test` are not defined in this file's visible
# part; presumably they come from the star-import of `train` - TODO confirm.
for item in test_items:
    # Deserialize onto the CPU first; tensors are moved to `device` below.
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed on the command line.
    state = torch.load(item, map_location="cpu")

    # Cast every entry to float32, then move it to the target device, before
    # handing the mapping to the model.
    model.load_state_dict(
        {
            name: tensor.to(torch.float32).to(device)
            for name, tensor in state.items()
        }
    )

    # `test` yields four values, unpacked here for use after evaluation.
    loss, acc, all_targets, all_predicts = test(model=model)