Update Q workflow
This commit is contained in:
		| @@ -41,8 +41,8 @@ class QuantTransformer(Model): | ||||
|   def __init__( | ||||
|     self, | ||||
|     d_feat=6, | ||||
|     hidden_size=64, | ||||
|     num_layers=2, | ||||
|     hidden_size=48, | ||||
|     depth=5, | ||||
|     dropout=0.0, | ||||
|     n_epochs=200, | ||||
|     lr=0.001, | ||||
| @@ -62,7 +62,7 @@ class QuantTransformer(Model): | ||||
|     # set hyper-parameters. | ||||
|     self.d_feat = d_feat | ||||
|     self.hidden_size = hidden_size | ||||
|     self.num_layers = num_layers | ||||
|     self.depth = depth | ||||
|     self.dropout = dropout | ||||
|     self.n_epochs = n_epochs | ||||
|     self.lr = lr | ||||
| @@ -79,7 +79,7 @@ class QuantTransformer(Model): | ||||
|       "Transformer parameters setting:" | ||||
|       "\nd_feat : {}" | ||||
|       "\nhidden_size : {}" | ||||
|       "\nnum_layers : {}" | ||||
|       "\ndepth : {}" | ||||
|       "\ndropout : {}" | ||||
|       "\nn_epochs : {}" | ||||
|       "\nlr : {}" | ||||
| @@ -93,7 +93,7 @@ class QuantTransformer(Model): | ||||
|       "\nseed : {}".format( | ||||
|         d_feat, | ||||
|         hidden_size, | ||||
|         num_layers, | ||||
|         depth, | ||||
|         dropout, | ||||
|         n_epochs, | ||||
|         lr, | ||||
| @@ -112,7 +112,9 @@ class QuantTransformer(Model): | ||||
|       np.random.seed(self.seed) | ||||
|       torch.manual_seed(self.seed) | ||||
|  | ||||
|     self.model = TransformerModel(d_feat=self.d_feat) | ||||
|     self.model = TransformerModel(d_feat=self.d_feat, | ||||
|                                   embed_dim=self.hidden_size, | ||||
|                                   depth=self.depth) | ||||
|     self.logger.info('model: {:}'.format(self.model)) | ||||
|     self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model))) | ||||
|    | ||||
|   | ||||
		Reference in New Issue
	
	Block a user