# Source code for alf.test.case

# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple wrapper over unittest.TestCase to provide extra functionality."""

import numpy as np
import torch
import unittest
import alf
from alf.utils import common


class TestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with tensor and array assertions.

    ``setUp()`` resets ALF global state before every test, and
    ``assertEqual()`` called with two ``torch.Tensor`` objects is routed to
    ``assertTensorEqual()``.
    """

    def __init__(self, methodName='runTest'):
        """
        Args:
            methodName (str): name of the test method to run; forwarded to
                ``unittest.TestCase``.
        """
        super().__init__(methodName)
        # Registering the method by name lets ``assertEqual(t1, t2)`` do an
        # element-wise comparison when both arguments are tensors.
        self.addTypeEqualityFunc(torch.Tensor, 'assertTensorEqual')

    def setUp(self):
        """Reset ALF global state so that tests do not affect each other."""
        # create_environment() might have been used in other tests.
        # We need to reset_configs() so we can avoid the error of configuring
        # a function after it is used.
        alf.reset_configs()
        # Some tests may create a global env. We need to close it.
        alf.close_env()
        common.set_random_seed(1)
        alf.summary.reset_global_counter()

    @staticmethod
    def _max_abs_tensor_diff(t1, t2):
        """Return the largest element of ``|t1 - t2|``, or 0.0 if it is empty.

        Both tensors are moved to CPU first (as ``assertTensorEqual`` does) so
        that tensors living on different devices can be compared.
        """
        diff = torch.abs(t1.cpu() - t2.cpu())
        # torch.max() raises on a zero-element tensor; with nothing to compare
        # the difference is taken to be zero.
        if diff.numel() == 0:
            return 0.0
        return torch.max(diff)

    @staticmethod
    def _max_abs_array_diff(a1, a2):
        """Return the largest element of ``|a1 - a2|``, or 0.0 if it is empty."""
        diff = np.abs(a1 - a2)
        # np.max() raises ValueError on a zero-size array; with nothing to
        # compare the difference is taken to be zero.
        if diff.size == 0:
            return 0.0
        return np.max(diff)

    def assertTensorEqual(self, t1, t2, msg=None):
        """Fail unless every element of ``t1`` equals that of ``t2``.

        The comparison is done on CPU copies. Unlike the array assertions,
        the shapes themselves are not compared.

        Args:
            t1 (torch.Tensor): first tensor.
            t2 (torch.Tensor): second tensor.
            msg (str): optional message used on failure.
        """
        self.assertIsInstance(t1, torch.Tensor,
                              'First argument is not a Tensor')
        self.assertIsInstance(t2, torch.Tensor,
                              'Second argument is not a Tensor')
        if not torch.all(t1.cpu() == t2.cpu()):
            standardMsg = '%s != %s' % (t1, t2)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertTensorClose(self, t1, t2, epsilon=1e-6, msg=None):
        """Fail unless ``max(|t1 - t2|) <= epsilon``.

        Empty tensors are considered close.

        Args:
            t1 (torch.Tensor): first tensor.
            t2 (torch.Tensor): second tensor.
            epsilon (float): largest allowed absolute element-wise difference.
            msg (str): optional message used on failure.
        """
        self.assertIsInstance(t1, torch.Tensor,
                              'First argument is not a Tensor')
        self.assertIsInstance(t2, torch.Tensor,
                              'Second argument is not a Tensor')
        diff = self._max_abs_tensor_diff(t1, t2)
        # ``not (diff <= epsilon)`` rather than ``diff > epsilon`` so that a
        # NaN difference is reported as a failure.
        if not (diff <= epsilon):
            standardMsg = '%s is not close to %s. diff=%s' % (t1, t2, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertTensorNotClose(self, t1, t2, epsilon=1e-6, msg=None):
        """Fail if ``max(|t1 - t2|) < epsilon``.

        Note that a difference exactly equal to ``epsilon`` satisfies both
        this assertion and ``assertTensorClose``. Empty tensors have a
        difference of zero.

        Args:
            t1 (torch.Tensor): first tensor.
            t2 (torch.Tensor): second tensor.
            epsilon (float): differences strictly below this are "close".
            msg (str): optional message used on failure.
        """
        self.assertIsInstance(t1, torch.Tensor,
                              'First argument is not a Tensor')
        self.assertIsInstance(t2, torch.Tensor,
                              'Second argument is not a Tensor')
        if self._max_abs_tensor_diff(t1, t2) < epsilon:
            standardMsg = '%s is actually close to %s' % (t1, t2)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertArrayEqual(self, t1, t2, msg=None):
        """Fail unless ``t1`` and ``t2`` have equal shapes and elements.

        Args:
            t1: anything accepted by ``np.array()``.
            t2: anything accepted by ``np.array()``.
            msg (str): optional message used on failure.
        """
        t1 = np.array(t1)
        t2 = np.array(t2)
        self.assertEqual(t1.shape, t2.shape, msg=msg)
        if not np.all(t1 == t2):
            standardMsg = '%s != %s' % (t1, t2)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertArrayClose(self, t1, t2, epsilon=1e-6, msg=None):
        """Fail unless shapes match and ``max(|t1 - t2|) <= epsilon``.

        Empty arrays of equal shape are considered close.

        Args:
            t1: anything accepted by ``np.array()``.
            t2: anything accepted by ``np.array()``.
            epsilon (float): largest allowed absolute element-wise difference.
            msg (str): optional message used on failure.
        """
        t1 = np.array(t1)
        t2 = np.array(t2)
        self.assertEqual(t1.shape, t2.shape, msg=msg)
        diff = self._max_abs_array_diff(t1, t2)
        # ``not (diff <= epsilon)`` rather than ``diff > epsilon`` so that a
        # NaN difference is reported as a failure.
        if not (diff <= epsilon):
            standardMsg = '%s is not close to %s. diff=%s' % (t1, t2, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertArrayNotClose(self, t1, t2, epsilon=1e-6, msg=None):
        """Fail if shapes differ or ``max(|t1 - t2|) < epsilon``.

        Note that a difference exactly equal to ``epsilon`` satisfies both
        this assertion and ``assertArrayClose``. Empty arrays have a
        difference of zero.

        Args:
            t1: anything accepted by ``np.array()``.
            t2: anything accepted by ``np.array()``.
            epsilon (float): differences strictly below this are "close".
            msg (str): optional message used on failure.
        """
        t1 = np.array(t1)
        t2 = np.array(t2)
        self.assertEqual(t1.shape, t2.shape, msg=msg)
        if self._max_abs_array_diff(t1, t2) < epsilon:
            standardMsg = '%s is actually close to %s' % (t1, t2)
            self.fail(self._formatMessage(msg, standardMsg))