def compare_state_dict(dict1, dict2): # compare keys for key in dict1: if key not in dict2: return False for key in dict2: if key not in dict1: return False for (k,v) in dict1.items(): if not torch.all(torch.isclose(v, dict2[k])) return False return True
Comparing two pytorch dicts