RNN hidden state
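```py
import torch

class rnn_(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        # forward() needs these to build the initial hidden state
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = torch.nn.RNN(input_size, hidden_size,
                                num_layers, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len, input_size) because batch_first=True.
        # h0 is always (num_layers, batch, hidden_size), even with batch_first=True.
        # Use x.size(0), not the batch_size passed to __init__: the last batch may be smaller.
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, h = self.rnn(x, h)
        return out
```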
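Create the hidden state exactly there, inside forward() and sized from the current batch with x.size(0); creating it anywhere else often causes baffling errors! One common cause: a hidden state built once from a fixed batch_size no longer matches the last batch, which is usually smaller.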
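A minimal sketch of the failure, assuming hypothetical sizes (10 sequences, batch_size=4, so the last batch has only 2 samples): a hidden state sized from the current batch works for every batch, while one built once from the nominal batch_size is rejected by nn.RNN on the short last batch.

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=4, hidden_size=8, num_layers=2, batch_first=True)

# Hypothetical data: 10 sequences, batch_size=4 -> batches of 4, 4 and 2
data = torch.randn(10, 5, 4)                # (samples, seq_len, input_size)
batch_size = 4
fixed_h0 = torch.zeros(2, batch_size, 8)    # built once from the nominal batch_size

results = []
for start in range(0, len(data), batch_size):
    x = data[start:start + batch_size]
    # Sized from the current batch, as in forward(): works for every batch
    h0 = torch.zeros(2, x.size(0), 8)
    out, _ = rnn(x, h0)
    try:
        rnn(x, fixed_h0)                    # fixed-size h0 fails on the short last batch
        fixed_ok = True
    except RuntimeError:
        fixed_ok = False
    results.append((x.size(0), tuple(out.shape), fixed_ok))

print(results)
```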
*****
Issues with hand-written Datasets
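Build the complete input and target tensors in __init__, and only do indexing in __getitem__. The data shape is **(total number of samples, other dims)**, where "other dims" are, for example, (channel, height, width) for images, (28, 28) for MNIST, and (sequence length) for text.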
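```py
import numpy as np
import torch
from torch.utils import data

class qohdataset(data.Dataset):
    """
    Dataset must define __getitem__ and __len__
    """
    def __init__(self, qoh):
        # qoh: list of sequences, each a list of up to 13 feature vectors of length 47
        def padding(ele, num):
            # append all-zero feature vectors (in place) until ele has num steps
            difference = num - len(ele)
            for _ in range(difference):
                ele.append(np.zeros((47,)))
        self.qoh = qoh
        for i in self.qoh:
            if len(i) < 13:
                padding(i, 13)
        # float32 rather than int: RNN layers expect floating-point input
        self.qoh = np.array(self.qoh, dtype=np.float32)  # (num_samples, 13, 47)
        self.qoh = torch.from_numpy(self.qoh)
        # next-step prediction: input is steps 0..11, target is steps 1..12
        self.seq = self.qoh[:, 0:12, :]
        self.tar = self.qoh[:, 1:13, :]

    def __getitem__(self, index):
        """
        Returns one data pair (x, y) at position index; both are tensors.
        """
        x = self.seq[index, ...]
        y = self.tar[index, ...]
        return x, y

    def __len__(self):
        # number of samples, so 0 <= index < len (shape[1] would be the sequence length)
        return self.qoh.shape[0]
```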
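The same pattern can be checked on a small hypothetical dataset (random MNIST-shaped data, not the author's data): tensors of shape (total, other dims) are built once in __init__, __getitem__ only indexes, and __len__ returns the size of the first dimension.

```python
import torch
from torch.utils import data

class ToyImageDataset(data.Dataset):
    """Hypothetical dataset following the same pattern: tensors built once in __init__."""
    def __init__(self, num_samples):
        g = torch.Generator().manual_seed(0)
        # (total, other dims): here "other dims" is MNIST-like (28, 28)
        self.images = torch.rand(num_samples, 28, 28, generator=g)
        self.labels = torch.randint(0, 10, (num_samples,), generator=g)

    def __getitem__(self, index):
        # indexing only, no preprocessing here
        return self.images[index], self.labels[index]

    def __len__(self):
        # the first dimension is the number of samples
        return self.images.shape[0]

ds = ToyImageDataset(100)
x, y = ds[0]
print(len(ds), tuple(x.shape), y.dtype)
```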
*****
DataLoader issues
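A DataLoader yields tensors of shape (batch, other dims), where "other dims" are the same as in the Dataset. In general you only need to define collate_fn when the input sequences have different lengths, because the default collate_fn stacks the samples and stacking requires equal shapes; otherwise just use the DataLoader as is.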
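A sketch of both cases, assuming hypothetical random data with 47 features per step: equal-length samples go through the default collate_fn unchanged, while variable-length sequences get a custom collate_fn (here the hypothetical `pad_collate`) that zero-pads each batch with `torch.nn.utils.rnn.pad_sequence`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Equal-length samples: the default collate_fn stacks them into (batch, other dims)
fixed = [(torch.randn(12, 47), torch.randn(12, 47)) for _ in range(10)]
x, y = next(iter(DataLoader(fixed, batch_size=4)))
print(tuple(x.shape))  # (4, 12, 47)

# Variable-length samples (hypothetical): stacking would fail, so pad per batch
ragged = [torch.randn(n, 47) for n in (5, 12, 8, 3)]

def pad_collate(batch):
    """Pad a list of (seq_len, features) tensors to the longest one in the batch."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True)  # (batch, max_len, features)
    return padded, lengths

padded, lengths = next(iter(DataLoader(ragged, batch_size=4, collate_fn=pad_collate)))
print(tuple(padded.shape), lengths.tolist())  # (4, 12, 47) [5, 12, 8, 3]
```

Returning the original lengths alongside the padded batch lets later code (for example `pack_padded_sequence`) ignore the padding steps.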
*****
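Tensor data types have requirements. The common types are:
float (float32), double (float64), half (float16), short (int16), int (int32), long (int64)
For example, layers such as nn.Linear and nn.RNN expect floating-point input (float32 by default), while class-index targets for nn.CrossEntropyLoss must be long.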
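A minimal sketch of the typical mismatch, assuming hypothetical integer data: `torch.from_numpy` keeps NumPy's integer dtype, a float32 layer rejects that input until it is converted with `.float()`, and integer class labels stay long.

```python
import numpy as np
import torch

# Hypothetical integer data, e.g. counts loaded from a file
counts = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
x_long = torch.from_numpy(counts)      # keeps NumPy's dtype: torch.int64 (long)
layer = torch.nn.Linear(3, 2)          # weights are float32 by default

try:
    layer(x_long)                      # long input vs float weights -> dtype error
    long_ok = True
except RuntimeError:
    long_ok = False

x_float = x_long.float()               # same as .to(torch.float32)
out = layer(x_float)

# Class-index targets for cross-entropy must be long (int64)
labels = torch.tensor([0, 1])          # integer literals default to torch.int64
loss = torch.nn.functional.cross_entropy(out, labels)

print(x_long.dtype, long_ok, x_float.dtype, out.dtype, labels.dtype)
```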