Source code for AIModels.LocalInformer

'''
Modified version of the InformerForPrediction class
===================================================

This modules contains routines modified from the original InformerForPrediction class
from the Huggingface library. The modifications are made to allow for the use of the
InformerModel class in the training and prediction of time series data with weights and
deterministic loss functions. The original Informer prediction classes can be found at
`https://huggingface.co/docs/transformers/main/model_doc/informer`.

Main Modifications:
-------------------

1. The InformerForPrediction class has been modified to allow for the use of deterministic
loss functions and weights in the training of time series data. The modifications are made
to the forward method of the class. The original forward method is replaced with a new
forward method that allows for the use of deterministic loss functions and weights in the
training of time series data. The new forward method also allows for the use of the
InformerModel class in the training and prediction of time series data with weights and
deterministic loss functions.

2. The generate method of the InformerForPrediction class has been modified to allow for
the use of deterministic loss functions and weights in the training of time series data. The
modifications are made to the generate method of the class. The original generate method is
replaced with a new generate method that allows for the use of deterministic loss functions
and weights in the training of time series data.

3. The output_distribution method of the InformerForPrediction class has been modified to
allow for the use of deterministic loss functions and weights in the training of time series
data. The modifications are made to the output_distribution method of the class. The original
output_distribution method is replaced with a new output_distribution method that allows for
the use of deterministic loss functions and weights in the training of time series data.

4. The InformerForPrediction class has been modified to allow for the use of deterministic
loss functions and weights in the training of time series data. The modifications are made
to the forward method of the class. The original forward method is replaced with a new
forward method that allows for the use of deterministic loss functions and weights in the
training of time series data. The new forward method also allows for the use of the
InformerModel class in the training and prediction of time series data with weights and
deterministic loss functions.

Utilities 
---------

'''

import transformers as tr
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Union, Tuple
from transformers.modeling_outputs import Seq2SeqTSModelOutput, SampleTSPredictionOutput,Seq2SeqTSPredictionOutput
from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
# from transformers.models.informer.modeling_informer import InformerModel, INFORMER_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.informer.modeling_informer import INFORMER_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.informer.configuration_informer import InformerConfig
from transformers.models.informer.modeling_informer import InformerPreTrainedModel
from transformers.models.informer.modeling_informer import StudentTOutput, NormalOutput, NegativeBinomialOutput
from transformers.models.informer.modeling_informer import nll, weighted_average
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.informer.modeling_informer import InformerMeanScaler, InformerStdScaler, InformerNOPScaler
from transformers.models.informer.modeling_informer import InformerFeatureEmbedder
from transformers.models.informer.modeling_informer import InformerEncoder, InformerDecoder


[docs] class InformerModel(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) if config.scaling == "mean" or config.scaling is True: self.scaler = InformerMeanScaler(dim=1, keepdim=True) elif config.scaling == "std": self.scaler = InformerStdScaler(dim=1, keepdim=True) else: self.scaler = InformerNOPScaler(dim=1, keepdim=True) if config.num_static_categorical_features > 0: self.embedder = InformerFeatureEmbedder( cardinalities=config.cardinality, embedding_dims=config.embedding_dimension, ) # transformer encoder-decoder and mask initializer self.encoder = InformerEncoder(config) self.decoder = InformerDecoder(config) # Initialize weights and apply final processing self.post_init() @property def _past_length(self) -> int: return self.config.context_length + max(self.config.lags_sequence)
[docs] def get_lagged_subsequences( self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0 ) -> torch.Tensor: """ Returns lagged subsequences of a given sequence PARAMETERS ---------- sequence: Tensor The sequence from which lagged subsequences should be extracted. Shape: (N, T, C). subsequences_length : int Length of the subsequences to be extracted. shift: int Shift the lags by this amount back. RETURNS ------- Tensor: Returns a tensor of shape (N, S, C, I), where S = subsequences_length and I = len(indices), containing lagged subsequences. Specifically, lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :]. """ sequence_length = sequence.shape[1] indices = [lag - shift for lag in self.config.lags_sequence] if max(indices) + subsequences_length > sequence_length: raise ValueError( f"lags cannot go further than history length, found lag {max(indices)} " f"while history length is only {sequence_length}" ) lagged_values = [] for lag_index in indices: begin_index = -lag_index - subsequences_length end_index = -lag_index if lag_index > 0 else None lagged_values.append(sequence[:, begin_index:end_index, ...]) return torch.stack(lagged_values, dim=-1)
[docs] def create_network_inputs( self, past_values: torch.Tensor, past_time_features: torch.Tensor, static_categorical_features: Optional[torch.Tensor] = None, static_real_features: Optional[torch.Tensor] = None, past_observed_mask: Optional[torch.Tensor] = None, future_values: Optional[torch.Tensor] = None, future_time_features: Optional[torch.Tensor] = None, ): # time feature time_feat = ( torch.cat( ( past_time_features[:, self._past_length - self.config.context_length :, ...], future_time_features, ), dim=1, ) if future_values is not None else past_time_features[:, self._past_length - self.config.context_length :, ...] ) # target if past_observed_mask is None: past_observed_mask = torch.ones_like(past_values) context = past_values[:, -self.config.context_length :] observed_context = past_observed_mask[:, -self.config.context_length :] _, loc, scale = self.scaler(context, observed_context) #Allow for different input and output features if future_values is not None: if(self.config.input_size != self.config.output_size): batch_size, seq_len, _ = future_values.shape kaz = torch.zeros(batch_size,seq_len,self.config.input_size,device=torch.device('mps')) kaz[...,:self.config.output_size] = future_values[...,:self.config.output_size] else: kaz = future_values inputs = ( (torch.cat((past_values, kaz), dim=1) - loc) / scale if future_values is not None else (past_values - loc) / scale ) # static features log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p() log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat((log_abs_loc, log_scale), dim=1) if static_real_features is not None: static_feat = torch.cat((static_real_features, static_feat), dim=1) if static_categorical_features is not None: embedded_cat = self.embedder(static_categorical_features) static_feat = torch.cat((embedded_cat, static_feat), dim=1) expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1) # all features features = torch.cat((expanded_static_feat, time_feat), dim=-1) # lagged features subsequences_length = ( self.config.context_length + self.config.prediction_length if future_values is not None else self.config.context_length ) lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length) lags_shape = lagged_sequence.shape reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]: raise ValueError( f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match" ) # transformer inputs transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) return transformer_inputs, loc, scale, static_feat
[docs] def get_encoder(self): return self.encoder
[docs] def get_decoder(self): return self.decoder
# @add_start_docstrings_to_model_forward(INFORMER_INPUTS_DOCSTRING) # @replace_return_docstrings(output_type=Seq2SeqTSModelOutput, config_class=_CONFIG_FOR_DOC)
[docs] def forward( self, past_values: torch.Tensor, past_time_features: torch.Tensor, past_observed_mask: torch.Tensor, static_categorical_features: Optional[torch.Tensor] = None, static_real_features: Optional[torch.Tensor] = None, future_values: Optional[torch.Tensor] = None, future_time_features: Optional[torch.Tensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" python: >>> from huggingface_hub import hf_hub_download >>> import torch >>> from transformers import InformerModel >>> file = hf_hub_download( ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" ... ) >>> batch = torch.load(file) >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") >>> # during training, one provides both past and future values >>> # as well as possible additional features >>> outputs = model( ... past_values=batch["past_values"], ... past_time_features=batch["past_time_features"], ... past_observed_mask=batch["past_observed_mask"], ... static_categorical_features=batch["static_categorical_features"], ... static_real_features=batch["static_real_features"], ... future_values=batch["future_values"], ... future_time_features=batch["future_time_features"], ... ) >>> last_hidden_state = outputs.last_hidden_state """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_inputs, loc, scale, static_feat = self.create_network_inputs( past_values=past_values, past_time_features=past_time_features, past_observed_mask=past_observed_mask, static_categorical_features=static_categorical_features, static_real_features=static_real_features, future_values=future_values, future_time_features=future_time_features, ) if encoder_outputs is None: enc_input = transformer_inputs[:, : self.config.context_length, ...] encoder_outputs = self.encoder( inputs_embeds=enc_input, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) dec_input = transformer_inputs[:, self.config.context_length :, ...] decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if not return_dict: return decoder_outputs + encoder_outputs + (loc, scale, static_feat) return Seq2SeqTSModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, loc=loc, scale=scale, static_features=static_feat, )
[docs] class InformerForPrediction(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) self.model = InformerModel(config) if config.distribution_output == "student_t": self.distribution_output = StudentTOutput(dim=config.output_size) elif config.distribution_output == "normal": self.distribution_output = NormalOutput(dim=config.output_size) elif config.distribution_output == "negative_binomial": self.distribution_output = NegativeBinomialOutput(dim=config.output_size) else: raise ValueError(f"Unknown distribution output {config.distribution_output}") self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) self.target_shape = self.distribution_output.event_shape if config.loss == "nll": self.loss = nll else: raise ValueError(f"Unknown loss function {config.loss}") # Initialize weights of distribution_output and apply final processing self.post_init()
[docs] def output_params(self, dec_output): return self.parameter_projection(dec_output)
[docs] def get_encoder(self): return self.model.get_encoder()
[docs] def get_decoder(self): return self.model.get_decoder()
[docs] @torch.jit.ignore def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution: sliced_params = params if trailing_n is not None: sliced_params = [p[:, -trailing_n:] for p in params] return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale)
# @add_start_docstrings_to_model_forward(INFORMER_INPUTS_DOCSTRING) # @replace_return_docstrings(output_type=Seq2SeqTSModelOutput, config_class=_CONFIG_FOR_DOC)
[docs] def forward( self, past_values: torch.Tensor, past_time_features: torch.Tensor, past_observed_mask: torch.Tensor, static_categorical_features: Optional[torch.Tensor] = None, static_real_features: Optional[torch.Tensor] = None, future_values: Optional[torch.Tensor] = None, future_time_features: Optional[torch.Tensor] = None, future_observed_mask: Optional[torch.Tensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" Returns: Examples: python: >>> from huggingface_hub import hf_hub_download >>> import torch >>> from transformers import InformerForPrediction >>> file = hf_hub_download( ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" ... ) >>> batch = torch.load(file) >>> model = InformerForPrediction.from_pretrained("huggingface/informer-tourism-monthly") >>> # during training, one provides both past and future values >>> # as well as possible additional features >>> outputs = model( ... past_values=batch["past_values"], ... past_time_features=batch["past_time_features"], ... past_observed_mask=batch["past_observed_mask"], ... static_categorical_features=batch["static_categorical_features"], ... static_real_features=batch["static_real_features"], ... future_values=batch["future_values"], ... future_time_features=batch["future_time_features"], ... ) >>> loss = outputs.loss >>> loss.backward() >>> # during inference, one only provides past values >>> # as well as possible additional features >>> # the model autoregressively generates future values >>> outputs = model.generate( ... past_values=batch["past_values"], ... past_time_features=batch["past_time_features"], ... past_observed_mask=batch["past_observed_mask"], ... static_categorical_features=batch["static_categorical_features"], ... static_real_features=batch["static_real_features"], ... future_time_features=batch["future_time_features"], ... ) >>> mean_prediction = outputs.sequences.mean(dim=1) """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if future_values is not None: use_cache = False outputs = self.model( past_values=past_values, past_time_features=past_time_features, past_observed_mask=past_observed_mask, static_categorical_features=static_categorical_features, static_real_features=static_real_features, future_values=future_values, future_time_features=future_time_features, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, ) # print(f'Shape of outputs is {outputs[0].shape}, keys are {outputs.keys()}') prediction_loss = None params = None # print(f'Shape of output is {outputs[0].shape}') if future_values is not None: params = self.output_params(outputs[0]) # outputs.last_hidden_state # print(f'Shape of params is {params[0].shape}') # print(f'Shapeof loc is {outputs[-3].shape}, scale is {outputs[-2].shape}') # loc is 3rd last and scale is 2nd last output distribution = self.output_distribution(params, loc=outputs[-3][:,:,:self.config.output_size], scale=outputs[-2][:,:,:self.config.output_size]) # print(f'Shape of distribution is {distribution.loc.shape} {distribution}') loss = self.loss(distribution, future_values) # print(f'future_values shape is {future_values.shape}, loss shape is {loss.shape}') if future_observed_mask is None: future_observed_mask = torch.ones_like(future_values) if len(self.target_shape) == 0: loss_weights = future_observed_mask else: loss_weights, _ = future_observed_mask.min(dim=-1, keepdim=False) # print(f'loss_weights shape is {loss_weights.shape}') # # Generate the initial sequence if self.config.loss_weight is not None: # bsize,tsize = loss_weights.shape # # initial_sequence = (torch.arange(tsize+1, 1, -1), dtype=torch.float32, device='mps')**32 # initial_sequence = (torch.zeros((bsize,tsize), dtype=torch.float32, device='mps')) # initial_sequence[:,0] = 1 loss_weights = self.config.loss_weight # Normalize the weights so they sum to 1 # loss_weights = initial_sequence / initial_sequence.sum().repeat(bsize,1) # print(f'W is {loss_weights}') prediction_loss = weighted_average(loss, weights=loss_weights) if not return_dict: outputs = ((params,) + outputs[1:]) if params is not None else outputs[1:] return ((prediction_loss,) + outputs) if prediction_loss is not None else outputs return Seq2SeqTSPredictionOutput( loss=prediction_loss, params=params, past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, loc=outputs.loc, scale=outputs.scale, static_features=outputs.static_features, )
[docs] @torch.no_grad() def generate( self, past_values: torch.Tensor, past_time_features: torch.Tensor, future_time_features: torch.Tensor, past_observed_mask: Optional[torch.Tensor] = None, static_categorical_features: Optional[torch.Tensor] = None, static_real_features: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> SampleTSPredictionOutput: r""" Greedily generate sequences of sample predictions from a model with a probability distribution head. Parameters: past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): Past values of the time series, that serve as context in order to predict the future. The sequence size of this tensor must be larger than the `context_length` of the model, since the model will use the larger size to construct lag features, i.e. additional values from the past which are added in order to serve as "extra context". The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of the past. The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as `static_categorical_features`, `static_real_features`, `past_time_features` and lags). Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of variates in the time series per time step. past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): Required time features, which the model internally will add to `past_values`. These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These could also be so-called "age" features, which basically help the model know "at which point in life" a time-series is. Age features have small values for distant past time steps and increase monotonically the more we approach the current time step. Holiday features are also a good example of time features. These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where the position encodings are learned from scratch internally as parameters of the model, the Time Series Transformer requires to provide additional time features. The Time Series Transformer only learns additional embeddings for `static_categorical_features`. Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features must but known at prediction time. The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): Required time features for the prediction window, which the model internally will add to sampled predictions. These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These could also be so-called "age" features, which basically help the model know "at which point in life" a time-series is. Age features have small values for distant past time steps and increase monotonically the more we approach the current time step. Holiday features are also a good example of time features. These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where the position encodings are learned from scratch internally as parameters of the model, the Time Series Transformer requires to provide additional time features. The Time Series Transformer only learns additional embeddings for `static_categorical_features`. Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features must but known at prediction time. The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in `[0, 1]`: - 1 for values that are **observed**, - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series. Static categorical features are features which have the same value for all time steps (static over time). A typical example of a static categorical feature is a time series ID. static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): Optional static real features which the model will add to the values of the time series. Static real features are features which have the same value for all time steps (static over time). A typical example of a static real feature is promotion information. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. Return: [`SampleTSPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of samples, prediction_length)` or `(batch_size, number of samples, prediction_length, input_size)` for multivariate predictions. """ outputs = self( static_categorical_features=static_categorical_features, static_real_features=static_real_features, past_time_features=past_time_features, past_values=past_values, past_observed_mask=past_observed_mask, future_time_features=future_time_features, future_values=None, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, use_cache=True, ) decoder = self.model.get_decoder() enc_last_hidden = outputs.encoder_last_hidden_state loc = outputs.loc scale = outputs.scale static_feat = outputs.static_features num_parallel_samples = self.config.num_parallel_samples repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0) repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0) repeated_past_values = ( past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc ) / repeated_scale expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1) features = torch.cat((expanded_static_feat, future_time_features), dim=-1) repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0) repeated_enc_last_hidden = enc_last_hidden.repeat_interleave(repeats=num_parallel_samples, dim=0) future_samples = [] # greedy decoding for k in range(self.config.prediction_length): lagged_sequence = self.model.get_lagged_subsequences( sequence=repeated_past_values, subsequences_length=1 + k, shift=1, ) lags_shape = lagged_sequence.shape # print(f'lags_shape is {lags_shape}') reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1) # print(f'repeated_past_values shape is {repeated_past_values.shape}, reshaped_lagged_sequence shape is {reshaped_lagged_sequence.shape}') decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1) dec_output = decoder(inputs_embeds=decoder_input, encoder_hidden_states=repeated_enc_last_hidden) dec_last_hidden = dec_output.last_hidden_state params = self.parameter_projection(dec_last_hidden[:, -1:]) distr = self.output_distribution(params, loc=repeated_loc, scale=repeated_scale) # print(f'params shape is {params[0].shape}, loc shape is {repeated_loc.shape}, scale shape is {repeated_scale.shape}') # print(f'distr shape is {distr.loc.shape}') next_sample = distr.sample() # print(f'next_sample shape is {next_sample.shape}') # Make sure input values have all the same length repeated_past_values = torch.cat( (repeated_past_values[:,:,:], (next_sample - repeated_loc) / repeated_scale), dim=1 ) future_samples.append(next_sample) concat_future_samples = torch.cat(future_samples, dim=1) return SampleTSPredictionOutput( sequences=concat_future_samples.reshape( (-1, num_parallel_samples, self.config.prediction_length) + self.target_shape, ) )