Updates slbm
This commit is contained in:
		| @@ -131,7 +131,7 @@ def main(args): | |||||||
|  |  | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) |         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||||
|         future_x.to(args.device), future_y.to(args.device) |         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|         future_y_hat = model(future_x) |         future_y_hat = model(future_x) | ||||||
|         future_loss = criterion(future_y_hat, future_y) |         future_loss = criterion(future_y_hat, future_y) | ||||||
|         metric(future_y_hat, future_y) |         metric(future_y_hat, future_y) | ||||||
|   | |||||||
| @@ -130,7 +130,7 @@ def main(args): | |||||||
|  |  | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) |         xmetric = ComposeMetric(metric_cls(True), SaveMetric()) | ||||||
|         future_x.to(args.device), future_y.to(args.device) |         future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|         future_y_hat = model(future_x) |         future_y_hat = model(future_x) | ||||||
|         future_loss = criterion(future_y_hat, future_y) |         future_loss = criterion(future_y_hat, future_y) | ||||||
|         metric(future_y_hat, future_y) |         metric(future_y_hat, future_y) | ||||||
|   | |||||||
| @@ -22,11 +22,11 @@ class TestQuadraticSFunc(unittest.TestCase): | |||||||
|     """Test the quadratic function.""" |     """Test the quadratic function.""" | ||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         function = QuadraticSFunc([[0, 1], [0.5, 4], [1, 1]]) |         function = QuadraticSFunc({0: 1, 1: 2, 2: 1}) | ||||||
|         print(function) |         print(function) | ||||||
|         for x in (0, 0.5, 1): |         for x in (0, 0.5, 1): | ||||||
|             print("f({:})={:}".format(x, function(x))) |             print("f({:})={:}".format(x, function(x))) | ||||||
|         thresh = 0.2 |         thresh = 1e-7 | ||||||
|         self.assertTrue(abs(function(0) - 1) < thresh) |         self.assertTrue(abs(function(0) - 1) < thresh) | ||||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) |         self.assertTrue(abs(function(0.5) - 0.5 * 0.5 - 2 * 0.5 - 1) < thresh) | ||||||
|         self.assertTrue(abs(function(1) - 1) < thresh) |         self.assertTrue(abs(function(1) - 1 - 2 - 1) < thresh) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user