Practical Deep Learning for Coders Course - Lesson 3

This blog-post series captures my weekly notes while I attend the fastai v5 course conducted by the University of Queensland with fast.ai. So it’s off to week 3, where we learn more about how neural networks work and say the C word while learning about them …
Categories: fastai, fastaicourse

Author: Kurian Benoy

Published: May 10, 2022

Lesson Setup

There was a minor delay in streaming today’s lesson, because the session was being conducted in person by Jeremy at the University of Queensland. There were 130 people watching live on YouTube.

Jeremy started the lesson by saying that lessons 1 and 2 are usually easy for everyone, and that things usually start getting hard from lesson 3. There is also a lesson 0 on how to do fast.ai.

I have previously written about fastai lesson 0, where Jeremy explained how to do a fast.ai lesson through the following five steps:

  1. Watching lecture/book (watching the video first without trying anything)
  2. Running notebook and experimentation (going through lesson notebooks and experimenting stuff)
  3. Reproduce results (try the clean version of the fastai notebooks and see if you are able to understand and do things on your own)
  4. Working on a different dataset (play with a different dataset, participate in Kaggle …)

Studying with other people is the best way to retain your knowledge, so it’s great to participate in study groups like the Delft-fastai sessions.

This week, Jeremy showcased the student projects that received the highest number of votes in the “Share your work here” topic on the fastai forums. My work also got featured in the lesson 🙂.


Dogs vs Cats notebook: which image models are the best?

Today Jeremy featured the Paperspace Gradient platform. He has been using it for his own development and finds it amazing. He also arranged with them to keep fastbook regularly updated on the platform.

Important

The main point of lesson 2 is not about picking a particular platform and deploying models through JavaScript websites or online applications. The key thing is to understand the concept. There are two pieces:

  1. The training piece, at the end of which you get a model.pkl file (train.ipynb).
  2. The piece that takes inputs and spits out outputs. This separate step is deployment (app.ipynb).
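In fastai, the handoff between the two pieces is typically learn.export("model.pkl") at the end of training and load_learner("model.pkl") in the app. The sketch below mimics that same two-step split using only the Python standard library, so the idea can be seen without a GPU. The train and predict functions and the "small"/"big" threshold task are hypothetical stand-ins; unlike fastai, which pickles the whole Learner, this toy only saves the one learned number.

```python
import pickle
import tempfile
from pathlib import Path


def train(xs, labels, path):
    """Training piece (think train.ipynb): learn from labelled data, save the result to disk."""
    smalls = [x for x, y in zip(xs, labels) if y == "small"]
    bigs = [x for x, y in zip(xs, labels) if y == "big"]
    # "Learning" in this toy is just the midpoint between the two classes
    threshold = (max(smalls) + min(bigs)) / 2
    with open(path, "wb") as f:
        pickle.dump({"threshold": threshold}, f)


def predict(path, x):
    """Deployment piece (think app.ipynb): load the saved file and map an input to an output."""
    with open(path, "rb") as f:
        params = pickle.load(f)
    return "big" if x > params["threshold"] else "small"


with tempfile.TemporaryDirectory() as d:
    model_file = Path(d) / "model.pkl"
    train([1, 2, 8, 9], ["small", "small", "big", "big"], model_file)
    print(predict(model_file, 7.5))  # big
```

The only thing the deployment side needs from training is the saved file, which is why the two notebooks can live on completely different machines.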

Comparing image models by their baseline results together with their inference time helps us choose a good architecture. He tried the levit models, which didn’t work very well.
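One way to read an accuracy-versus-speed comparison is to keep only the models that no other model beats on both axes at once (the Pareto front). The sketch below assumes each model has been summarised as an error rate and a per-image inference time; the model names and numbers are made up purely to exercise the code and are not benchmark results.

```python
from typing import NamedTuple


class Result(NamedTuple):
    name: str
    error_rate: float         # lower is better
    seconds_per_image: float  # lower is better


def pareto_front(results):
    """Keep models that no other model matches or beats on both metrics (strictly on at least one)."""
    front = []
    for r in results:
        dominated = any(
            o.error_rate <= r.error_rate
            and o.seconds_per_image <= r.seconds_per_image
            and (o.error_rate < r.error_rate or o.seconds_per_image < r.seconds_per_image)
            for o in results
        )
        if not dominated:
            front.append(r)
    # Fastest first, so you can scan for the cheapest model that is accurate enough
    return sorted(front, key=lambda r: r.seconds_per_image)


# Hypothetical numbers, for illustration only
results = [
    Result("model_a", 0.05, 0.010),
    Result("model_b", 0.08, 0.005),
    Result("model_c", 0.09, 0.012),  # slower AND less accurate than model_a
    Result("model_d", 0.04, 0.030),
]
print([r.name for r in pareto_front(results)])  # ['model_b', 'model_a', 'model_d']
```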

From [13:52] in the video, he experiments with ConvNeXt tiny models from the timm library. They got really good accuracy, with a loss of almost 0.05. At the moment there are a lot of good computer vision architectures that beat ResNets comfortably. In this case, for predicting the 37 breeds of cats and dogs, we can find the categories in the dataset using the vocab of the model’s DataLoaders.

labels = model.dls.vocab
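In fastai, learn.predict returns the predicted label, its index, and a tensor of probabilities in the same order as dls.vocab, so the two can be zipped together into a label-to-probability mapping for an app to display. The sketch below shows that pairing with plain lists; the vocab subset and probability values are hypothetical placeholders for real model output.

```python
def probs_to_labels(vocab, probs):
    """Pair each category name with its predicted probability (vocab and probs share an order)."""
    if len(vocab) != len(probs):
        raise ValueError("need exactly one probability per category")
    # float() turns tensor elements (or numpy scalars) into plain Python floats
    return dict(zip(vocab, map(float, probs)))


vocab = ["Abyssinian", "beagle", "pug"]  # hypothetical subset of the 37 categories
probs = [0.1, 0.7, 0.2]                  # hypothetical model output
scores = probs_to_labels(vocab, probs)
print(max(scores, key=scores.get))  # beagle
```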

It’s very important to understand what’s inside a model. Using get_submodule in PyTorch, we can look at the individual layers of a neural network, along with each layer’s inputs and outputs. [21:24]
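get_submodule takes a dotted path such as "model.encoder.layers.0.self_attn_layer_norm" and walks it one name at a time; children of a ModuleList are registered under their index as a string ("0", "1", …), which is why numbers can appear in the path. The sketch below is a simplified, plain-Python illustration of that lookup: the Node class is a hypothetical stand-in for nn.Module, not PyTorch’s own code.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module: a type name plus named children."""
    kind: str
    children: dict = field(default_factory=dict)


def get_submodule(root, target):
    """Resolve a dotted path like 'encoder.layers.0' by walking child names in order."""
    if target == "":
        return root  # an empty path refers to the module itself
    node = root
    for name in target.split("."):
        if name not in node.children:
            raise AttributeError(f"{node.kind} has no submodule {name!r}")
        node = node.children[name]
    return node


tree = Node("Model", {
    "encoder": Node("Encoder", {
        "layers": Node("ModuleList", {
            # list entries are keyed by their index as a string
            "0": Node("EncoderLayer", {"self_attn_layer_norm": Node("LayerNorm")}),
        }),
    }),
})
print(get_submodule(tree, "encoder.layers.0.self_attn_layer_norm").kind)  # LayerNorm
```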

Let’s explore the architecture of a translation model that translates from many source languages into English.

from transformers import AutoModelForSeq2SeqLM


model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-mul-en")

Looking at model architecture

model
MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(64172, 512, padding_idx=64171)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(64172, 512, padding_idx=64171)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (decoder): MarianDecoder(
      (embed_tokens): Embedding(64172, 512, padding_idx=64171)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (activation_fn): SiLUActivation()
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=512, out_features=64172, bias=False)
)
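Because the printout lists the size of every layer, the number of learnable parameters in one encoder or decoder layer can be worked out by hand. A Linear(in, out) layer holds an out × in weight matrix plus out biases, and a LayerNorm with elementwise_affine=True holds one weight and one bias per feature. The helper names below are my own; the arithmetic only uses the sizes printed above (d_model = 512, feed-forward size 2048).

```python
def linear_params(in_features, out_features, bias=True):
    # weight is out_features x in_features, plus one bias per output feature
    return in_features * out_features + (out_features if bias else 0)


def layer_norm_params(dim):
    # elementwise_affine=True: one weight (gamma) and one bias (beta) per feature
    return 2 * dim


d_model, d_ff = 512, 2048
attention = 4 * linear_params(d_model, d_model)  # k_proj, v_proj, q_proj, out_proj

encoder_layer = (
    attention + layer_norm_params(d_model)                         # self_attn + its norm
    + linear_params(d_model, d_ff) + linear_params(d_ff, d_model)  # fc1, fc2
    + layer_norm_params(d_model)                                   # final_layer_norm
)
# A decoder layer additionally has encoder_attn and encoder_attn_layer_norm
decoder_layer = encoder_layer + attention + layer_norm_params(d_model)

print(f"{encoder_layer:,} {decoder_layer:,}")  # 3,152,384 4,204,032
```

Roughly two thirds of each encoder layer’s parameters sit in the attention projections and the first feed-forward layer, split about evenly between them.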

Looking at the self_attn_layer_norm of the first encoder layer (index 0)

# This is the LayerNorm applied after self-attention, not the attention block itself
layer_norm = model.get_submodule("model.encoder.layers.0.self_attn_layer_norm")
list(layer_norm.parameters())
[Parameter containing:
 tensor([0.3865, 0.6348, 0.6938, 0.7140, 1.1017, 1.0888, 0.7801, 0.7572, 0.7402,
         0.5655, 0.5940, 0.7477, 0.6920, 0.6781, 0.5128, 0.5862, 0.7173, 0.5140,
         0.5940, 0.5998, 0.5002, 0.5931, 0.3720, 0.8686, 0.6557, 0.7436, 0.7564,
         0.5402, 0.6773, 0.6831, 0.7060, 0.8484, 0.8874, 0.9380, 0.7360, 0.6073,
         0.7911, 0.6247, 0.6225, 0.7281, 0.7470, 0.8066, 0.6336, 0.5607, 0.6914,
         0.7630, 1.0365, 0.5133, 0.8260, 0.9167, 0.6362, 0.6375, 0.7296, 1.0838,
         0.7916, 0.8332, 1.0474, 0.9655, 0.9446, 0.8361, 0.9928, 0.7550, 0.8335,
         0.9597, 0.3449, 0.6119, 0.9266, 0.8208, 0.7301, 0.9969, 0.4639, 0.6579,
         1.0493, 0.9808, 0.9181, 0.7736, 0.7346, 0.9642, 1.2211, 1.3974, 1.3712,
         1.4836, 1.2050, 1.1015, 1.3986, 1.4113, 1.3771, 1.5623, 1.5389, 1.0727,
         1.5310, 1.3641, 1.5365, 1.4774, 1.4893, 1.4168, 1.5904, 1.5720, 1.3812,
         1.5914, 1.5096, 1.2807, 0.1877, 1.3947, 1.6565, 1.2572, 1.7532, 1.7136,
         1.5001, 1.7059, 1.6033, 1.5448, 1.5357, 1.5565, 1.5366, 1.3784, 1.6677,
         1.6570, 1.6885, 1.6925, 1.5795, 1.6837, 1.7601, 1.6240, 1.8309, 1.6668,
         1.7021, 1.7827, 1.8194, 1.8531, 1.9633, 1.7518, 1.9518, 1.8846, 2.0106,
         1.9608, 1.8964, 1.9245, 0.0996, 1.8191, 1.8534, 1.7096, 1.7831, 0.1533,
         2.0808, 1.8960, 2.1153, 1.8570, 2.0739, 2.1022, 2.0319, 1.3613, 1.9232,
         2.1441, 2.0704, 2.1557, 2.1526, 2.2401, 2.0910, 1.8356, 2.1069, 1.7451,
         0.1487, 2.1800, 2.1589, 2.0273, 0.1957, 2.2119, 2.1048, 1.4881, 1.7567,
         2.2064, 2.1753, 2.2111, 2.1907, 2.1288, 1.8702, 2.1218, 2.1744, 2.2581,
         2.2565, 2.1913, 2.0952, 2.2975, 1.9853, 1.9851, 2.1758, 2.1094, 2.0666,
         2.0578, 1.7592, 2.1246, 2.1616, 2.1781, 2.1823, 2.4415, 2.0122, 1.9394,
         2.1719, 2.1455, 2.3547, 1.0006, 2.1169, 1.6765, 2.2037, 2.1994, 2.2939,
         2.1233, 2.1261, 2.1542, 2.1301, 2.0364, 2.2253, 2.1832, 2.2080, 2.0617,
         2.2758, 2.1373, 2.2573, 2.0367, 2.2055, 2.2531, 1.9362, 2.1346, 2.3110,
         1.8304, 2.2435, 2.0757, 2.1346, 2.0784, 2.2972, 1.9981, 2.2595, 2.3887,
         2.3544, 2.1077, 2.2306, 2.2086, 1.6925, 2.1120, 2.2147, 2.2832, 2.1880,
         2.0909, 2.1869, 2.3249, 2.2425, 2.2322, 2.2695, 2.3331, 0.1346, 0.2001,
         1.9555, 2.0758, 2.0961, 2.2567, 0.4750, 0.5842, 0.7058, 0.7570, 0.9744,
         1.0287, 0.9519, 0.8539, 0.6670, 0.0686, 0.5976, 0.6930, 0.7278, 0.5867,
         0.5813, 0.7097, 0.5000, 0.6474, 0.5425, 0.5578, 0.5803, 0.6271, 0.6408,
         0.0821, 0.6325, 0.8464, 0.9188, 0.7320, 0.1289, 0.1581, 0.7063, 0.8729,
         0.7022, 0.8077, 0.7002, 0.6772, 0.5950, 0.6649, 0.7646, 0.4813, 0.7579,
         0.5831, 0.4914, 0.7263, 0.5337, 0.5253, 0.7073, 0.3907, 0.7041, 0.6702,
         0.4874, 0.5163, 0.2580, 0.6476, 0.5674, 0.4555, 0.5476, 0.5859, 0.6279,
         0.4089, 0.5099, 0.5995, 0.5399, 0.7964, 0.4036, 0.6919, 0.6908, 0.5914,
         0.5730, 0.6122, 0.4277, 0.4590, 0.7666, 0.9008, 0.3882, 0.1257, 0.6154,
         0.6206, 0.1595, 0.6308, 0.4924, 0.5181, 0.5823, 0.2778, 0.8624, 0.2661,
         0.7717, 0.9022, 1.2887, 1.2015, 0.6473, 0.4860, 0.4110, 0.4339, 0.5128,
         1.1724, 1.1852, 1.2922, 1.0709, 1.2392, 1.2499, 1.4100, 1.3137, 0.8466,
         1.4344, 1.4693, 0.6968, 0.1751, 0.1710, 0.1834, 1.4736, 1.6201, 1.3277,
         1.6475, 1.4915, 1.5697, 1.4164, 1.7855, 0.0784, 1.7240, 1.5680, 1.7145,
         1.9040, 1.7964, 1.9526, 1.9328, 2.0737, 1.9253, 1.7730, 2.2707, 2.0602,
         0.4566, 0.5279, 2.1403, 2.0589, 2.0557, 2.1391, 2.1761, 1.8147, 2.0583,
         1.8788, 2.0470, 2.0793, 2.0560, 2.2968, 0.2280, 2.2384, 0.1449, 2.3148,
         2.2568, 2.1043, 2.2506, 0.1906, 2.1942, 2.3548, 2.2405, 2.1008, 2.2179,
         2.2754, 0.6110, 0.2974, 1.9307, 2.1931, 2.0484, 2.0105, 2.1261, 2.0659,
         2.1462, 2.1739, 1.9466, 2.4105, 2.2565, 2.0342, 2.1688, 0.2608, 2.0383,
         2.0664, 1.9995, 2.1393, 2.2680, 2.0550, 2.2346, 1.9870, 2.0796, 1.9112,
         1.2930, 2.2390, 2.1678, 0.1801, 2.0002, 1.6783, 2.0918, 2.3177, 1.8342,
         1.9244, 1.9471, 2.2717, 2.1227, 2.2932, 2.3473, 1.7774, 2.0945, 2.3712,
         2.1550, 2.0802, 0.1087, 2.2277, 1.9290, 2.2212, 2.0705, 1.8797, 2.1542,
         0.3608, 2.1922, 2.1362, 2.1825, 1.9593, 2.1429, 0.2623, 1.6499, 2.0807,
         2.0261, 2.1480, 1.9283, 0.1497, 2.1901, 2.0398, 2.0140, 2.5195, 2.0685,
         1.4206, 2.0745, 2.2225, 0.1621, 2.2012, 0.4932, 2.0481, 2.1097, 2.3599,
         0.4743, 1.9034, 2.2135, 2.0947, 2.1751, 1.7660, 2.4012, 2.1536, 1.9608,
         2.1268, 1.9698, 2.2014, 2.3058, 2.1618, 1.8719, 1.9626, 2.2343],
        requires_grad=True),
 Parameter containing:
 tensor([ 7.7839e-02,  1.4282e-01, -6.7494e-02,  6.3598e-02, -1.2071e-01,
          7.2978e-02, -1.3550e-01,  3.5607e-02, -4.4458e-02, -4.4257e-03,
         -3.3140e-01, -8.8216e-02,  2.0695e-01, -1.7521e-01, -7.0075e-02,
         -2.3476e-01, -3.5785e-01, -4.3914e-01,  1.4167e-01, -9.0072e-02,
         -1.6590e-01, -2.4325e-02, -9.6055e-02, -3.2896e-01,  1.2258e-02,
         -1.0973e-03,  2.2662e-01, -1.3086e-02, -2.1918e-01, -4.5178e-02,
         -1.9418e-01, -1.8878e-02, -1.3459e-02, -2.9698e-01, -3.9941e-02,
         -9.4998e-02, -1.9507e-01, -4.1943e-02,  1.6590e-01, -1.1282e-01,
          1.1039e-01,  2.5711e-02, -1.5641e-01,  5.5295e-02, -1.1544e-01,
         -1.7157e-01, -2.8929e-01,  2.3132e-01, -2.6698e-01, -2.9870e-02,
         -1.4797e-02, -1.4169e-01, -4.8199e-03,  1.4835e-02, -8.9909e-02,
         -4.6198e-02, -2.8071e-01, -4.3290e-01, -1.6699e-01, -2.0422e-01,
         -5.8818e-03, -2.2520e-01,  6.2375e-03,  3.9504e-02,  7.5439e-02,
         -1.4287e-01, -5.1881e-01, -5.5721e-02, -5.5866e-02, -5.3829e-01,
         -1.4044e-02, -9.2953e-02, -1.1587e-01, -3.8476e-02, -2.5480e-01,
         -5.7539e-02, -2.8871e-01,  3.5020e-02, -3.1672e-02, -5.8393e-02,
         -3.2713e-01, -1.8932e-01,  8.0913e-02, -4.6087e-01, -6.5291e-02,
         -4.0539e-01,  6.4874e-02, -1.7552e-01,  4.5883e-02,  9.9371e-03,
          1.4575e-02, -1.4779e-01,  3.0300e-01, -7.1591e-02, -3.0603e-02,
         -5.1550e-02,  3.3196e-01, -2.6409e-01, -1.0252e-01, -9.0839e-02,
         -6.5229e-02,  4.6278e-03,  6.9909e-01, -3.8764e-01, -1.8178e-01,
         -1.6395e-01, -4.2978e-01, -9.3517e-02, -2.7543e-02, -1.2259e-01,
         -2.8473e-01,  2.5956e-01, -2.6014e-01,  5.4886e-02, -2.7227e-02,
         -1.3363e-01, -1.5168e-01,  8.5377e-02,  2.9195e-01,  2.1162e-02,
         -3.9784e-02,  4.4097e-02,  9.6993e-02,  1.4139e-01,  2.4818e-01,
          1.8267e-02, -1.1592e-01,  1.0816e-01,  7.5200e-02, -1.3003e-01,
         -8.0244e-03,  5.8670e-02, -3.7428e-01,  2.2588e-01, -5.0269e-01,
         -2.3895e-01,  8.2600e-02, -6.8347e-02, -1.0482e+00, -1.3551e-01,
          1.1412e-02, -2.1185e-01, -2.4042e-01,  2.4737e-02, -2.5176e-01,
          1.5020e-01, -2.3560e-01,  1.1241e-01, -6.4413e-02, -3.5118e-01,
         -1.2333e-01,  1.9045e-01,  4.3384e-02, -2.4544e-01, -4.2071e-01,
         -7.8986e-02, -4.2295e-02,  4.0794e-01, -3.2176e-01, -5.9337e-01,
         -6.2764e-02,  1.0759e-01, -4.0607e-01, -1.3816e-01, -3.3327e-01,
         -2.0288e-01, -3.6235e-01, -4.0601e-01, -3.3251e-01,  1.0679e-01,
         -3.2651e-01, -4.5523e-01,  9.8463e-03, -1.7090e-01,  5.6157e-02,
         -2.0125e-01, -1.0815e-01, -1.1430e-01, -4.0327e-02, -3.9167e-01,
         -3.6428e-01, -4.4570e-01, -8.9959e-02, -4.9760e-01, -1.0579e-01,
          1.3707e-01, -7.0252e-02,  2.7966e-02, -2.4773e-01, -5.1971e-04,
          8.3816e-02,  2.1685e-02, -6.9780e-01,  2.2206e-02,  3.3752e-01,
         -3.2891e-01, -7.8279e-02,  3.3331e-03, -1.5812e-01, -7.3529e-02,
         -2.4885e-01,  7.1563e-03, -1.0669e-01, -9.1697e-02,  3.7219e-02,
          2.2590e-01, -3.7476e-01,  8.3716e-02,  5.5841e-02, -3.0678e-01,
         -3.4485e-01, -4.4003e-01,  1.9830e-01, -4.7639e-01, -6.4421e-02,
         -2.7313e-01, -1.4385e-01, -1.4548e-01, -3.6821e-01,  2.6972e-01,
         -3.0483e-01,  4.1683e-02, -2.3375e-02, -2.3032e-01, -4.5438e-01,
         -2.6145e-01, -2.2000e-01, -5.7517e-02,  4.7594e-02, -9.9610e-03,
         -4.2952e-01,  1.8124e-01, -1.1407e-01, -2.7262e-01, -1.1815e-01,
         -2.3155e-01, -4.2597e-01, -4.4960e-01, -1.8752e-01, -3.0844e-01,
          3.5617e-02, -3.7852e-01, -3.3136e-01, -1.9491e-01, -2.1862e-01,
         -3.3167e-01,  2.6676e-01, -1.9840e-01, -3.3605e-01, -1.6330e-01,
         -6.2717e-02, -8.3715e-01, -2.5243e-01, -1.3302e-01, -3.6257e-01,
          5.8300e-01, -1.1160e-01, -1.1229e-01, -4.1968e-01, -1.0799e-01,
         -1.9890e-01, -6.1067e-02, -2.9817e-01, -6.8028e-02, -1.3047e-01,
         -8.3282e-01, -2.1888e-01, -1.1378e-01, -1.4994e-02, -3.3752e-01,
          1.4736e-01, -2.0098e-01, -3.8907e-01,  1.4387e-01, -1.3784e-01,
          1.6391e-02, -1.7244e-01,  7.5800e-02, -2.3648e-01, -3.8036e-01,
          1.9662e-01,  7.4968e-02, -1.1686e-01, -3.6071e-01, -7.9299e-02,
          1.8760e-01,  1.6195e-01, -3.2272e-01, -2.1438e-01, -7.2898e-02,
          9.8829e-02,  7.1539e-02,  1.3703e-01, -1.5568e-01,  6.3408e-04,
         -3.5787e-02, -2.7407e-01, -5.7378e-02, -2.0438e-01, -2.4371e-02,
          1.7313e-01, -4.1306e-01, -9.4938e-02,  3.8556e-02, -2.3727e-01,
          5.0274e-02, -5.2022e-02,  6.9763e-03,  1.2209e-01, -1.4279e-01,
         -2.5014e-01, -1.8495e-02, -1.3463e-02, -2.4504e-01, -1.3166e-01,
         -7.7291e-02,  7.7370e-02,  1.1513e-02, -7.0425e-02,  1.5736e-01,
         -2.1174e-01, -4.2664e-02, -2.9207e-01,  3.2393e-02,  2.1656e-02,
          9.9900e-02, -1.3805e-01,  2.5438e-01,  2.0831e-01,  3.6837e-02,
         -3.3914e-03,  4.1395e-01,  5.6420e-02,  8.9263e-02,  2.1450e-02,
         -5.5800e-02,  7.0606e-02, -4.1126e-02,  3.8725e-03, -1.5734e-01,
          5.0738e-01,  1.5756e-02,  3.4117e-01, -3.4182e-01,  2.3014e-01,
          2.9587e-02, -8.8264e-02,  3.3711e-01, -1.4313e-01,  1.5262e-01,
         -8.7762e-02,  2.4450e-01, -2.0987e-01,  1.9820e-01,  1.7844e-01,
          1.4303e-01, -5.0851e-02, -9.4576e-02,  1.8408e-02,  1.1286e-01,
          3.3272e-01,  3.5103e-01, -4.2428e-02, -1.9907e-01,  9.6479e-02,
          3.2967e-02, -1.9729e-01,  2.2756e-01,  8.3037e-02,  2.5401e-01,
          2.9031e-01, -1.5839e-01, -1.3418e-02,  1.0571e-01, -3.5190e-01,
         -8.5125e-02,  1.5848e-01,  2.5322e-01,  2.0388e-02,  1.4573e-01,
          1.7365e-02,  3.1611e-01, -2.0127e-01,  8.0616e-02, -1.4502e-02,
          6.7866e-01,  5.2572e-01,  6.3858e-02,  3.9846e-02,  5.1869e-01,
         -7.9728e-03,  3.9597e-01,  4.7967e-01,  2.7590e-01,  9.2782e-02,
          3.3310e-01,  2.2875e-01,  3.4428e-01,  4.6610e-01, -1.0366e-01,
          3.4020e-01,  2.3838e-01,  3.1878e-01,  1.2648e-01,  5.1629e-01,
          3.4091e-01,  4.3710e-01,  6.2221e-01,  1.7226e-01,  4.4662e-01,
          4.0081e-01,  7.7952e-01,  4.0586e-01,  1.0278e+00,  3.0402e-01,
          1.5113e-01,  8.0986e-02,  2.9811e-01,  6.0928e-01,  3.3816e-01,
          5.8209e-01,  5.3371e-01,  3.8662e-01,  2.0641e-01,  3.6023e-01,
          3.1196e-02,  4.9345e-01,  3.2226e-01,  2.7840e-01,  2.7691e-01,
          9.6109e-01,  1.2737e-01,  4.1566e-01,  3.9062e-01,  3.0825e-01,
          4.9397e-01,  4.5440e-01,  5.2856e-01,  2.1089e-01,  4.5024e-01,
          3.9093e-01,  4.3543e-01,  1.2896e-01,  3.8236e-01,  5.7791e-02,
          5.9610e-02,  3.2190e-01,  4.1077e-01,  6.7217e-01,  3.1503e-01,
          4.5539e-01,  3.8127e-01,  3.7299e-01,  4.9606e-01,  5.1592e-01,
          8.7739e-01,  1.2913e-01,  3.2640e-01,  5.1213e-01,  2.5983e-01,
          3.1244e-01,  8.0140e-02,  3.2804e-01,  1.5592e-01,  4.3599e-01,
          5.4296e-01,  3.3799e-01,  5.6262e-01,  9.3698e-01,  4.7990e-01,
          4.9927e-02,  4.0214e-01,  5.5437e-01,  4.3915e-01,  1.3080e-01,
          3.5957e-01,  6.5735e-02,  9.8948e-02,  4.7541e-01,  9.1836e-02,
          3.4417e-01,  3.5615e-01,  4.0770e-02,  4.5717e-01,  6.4114e-01,
          2.4542e-01,  5.0354e-01,  1.7951e-01,  6.0904e-01,  1.5958e+00,
          2.1165e-01,  3.6238e-01,  2.0053e-01,  4.2348e-01,  6.8393e-01,
          8.5349e-01,  1.3414e-01, -1.2184e-03,  4.1054e-01,  7.6441e-01,
          6.1769e-02,  3.8833e-01,  3.6897e-01,  3.5290e-01,  2.8261e-01,
          3.1730e-01,  4.8138e-01, -1.5993e-01,  3.7400e-01,  2.7083e-01,
          2.0941e-01,  5.4596e-01], requires_grad=True)]
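The two tensors above are this LayerNorm’s learnable parameters: first the weight (often written γ), then the bias (β), each with 512 entries, one per feature, matching LayerNorm((512,)). To see what they do, here is a plain-Python sketch of the LayerNorm computation for a single vector, assuming the eps=1e-05 shown in the printout and the biased variance that torch.nn.LayerNorm uses. It is an illustration, not PyTorch’s implementation.

```python
import math


def layer_norm(x, weight, bias, eps=1e-5):
    """Normalize one feature vector, then scale by weight (gamma) and shift by bias (beta)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance (divide by n, not n - 1)
    return [(v - mean) / math.sqrt(var + eps) * w + b for v, w, b in zip(x, weight, bias)]


x = [1.0, 2.0, 3.0, 4.0]
plain = layer_norm(x, [1.0] * 4, [0.0] * 4)   # roughly zero mean, unit variance
scaled = layer_norm(x, [2.0] * 4, [1.0] * 4)  # same shape, stretched by 2 and shifted by 1
print([round(v, 3) for v in plain])   # [-1.342, -0.447, 0.447, 1.342]
```

With weight = 1 and bias = 0 the layer is a pure normalization; training moves γ and β away from those values, which is why the printed weights range from below 0.1 to above 2.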

Looking at the parameters of the last decoder layer’s final_layer_norm

final = model.get_submodule("model.decoder.layers.5.final_layer_norm")
final_parameters = list(final.parameters())
print(f"{final_parameters = }")
final_parameters = [Parameter containing:
tensor([ 9.2454,  9.3895,  9.3544,  9.0685,  9.2224,  9.8569,  9.3900,  9.4416,
         9.4985,  9.2981,  9.5326,  9.2260,  8.8878,  9.4862,  9.5422,  9.3088,
         9.6653,  8.9836,  9.5670,  9.0307,  9.4179,  9.8929,  9.3411,  8.9442,
         8.3855,  9.0165,  9.5142,  9.5201,  9.2902,  9.5196,  8.8687,  9.3270,
         8.7709,  9.5791,  9.4227,  8.9457,  9.4278,  9.2320,  9.5537,  9.3045,
         9.2281,  9.1897,  8.9683,  9.3930,  9.1265,  9.2261,  9.1755,  9.2192,
         9.1531,  9.2323,  9.1581,  9.3413,  8.4585,  9.3836,  9.7359,  8.8970,
         9.4054,  8.9220,  9.2355,  9.6045,  9.6126,  9.4839,  9.2955,  9.2803,
         9.5649,  8.8892,  9.4749,  8.8119,  9.3922,  9.0771,  9.7973,  8.9035,
         9.7339,  9.1203,  9.5283,  8.9696,  8.4717,  9.3626,  9.3828,  7.9538,
         8.8453,  9.0190,  9.3108,  8.3297,  8.7236,  8.8562,  9.1680,  8.8641,
         7.8828,  8.7943,  8.4220,  8.8387,  9.3143,  8.1786,  9.1979,  9.0642,
         8.2838,  8.6224,  8.8548,  8.2028,  8.3914,  9.4564, 10.2469,  9.0537,
         8.7376,  9.3791,  8.5842,  8.4631,  8.6599,  8.8171,  7.8897,  8.6041,
         8.4556,  8.9208, 10.1143,  7.9758,  8.2237,  8.5698,  9.2252,  8.1479,
         8.0188,  8.9071,  8.1475,  9.6910,  8.2373,  8.2525,  8.6017,  8.4775,
         7.6445,  8.5943,  8.4234,  9.5359,  7.9101,  9.0395,  8.2788,  9.1683,
         8.9006,  9.3443, 10.6461,  8.7802,  8.7067,  8.1328,  8.4786,  9.5398,
         8.9038,  8.7195,  8.6432,  8.6484,  8.0920,  7.6238,  8.0674,  9.1098,
         8.9414,  8.5768,  8.5224,  8.2418,  8.2112,  8.5999,  8.4768,  8.9988,
         9.0594,  8.4397,  7.2651,  8.8350,  8.4989,  8.2867,  9.2490,  8.9484,
         9.0761,  9.4235,  8.6788,  8.3734,  8.5445,  8.6480,  8.5919,  8.7318,
         8.9115,  8.3845,  7.7635,  8.0614,  8.0440,  8.3904,  9.2142,  8.9592,
         8.3101,  8.5018,  8.3161,  8.6132,  8.5134,  8.6191,  9.2030,  8.4010,
         8.6543,  8.9678,  8.5206,  8.7887,  8.4305,  8.9793,  8.4836,  8.3803,
         8.5192,  9.0187,  8.2780,  8.4214,  8.5277,  8.3268,  8.6899,  8.8909,
         8.5217,  8.8556,  8.1597,  9.0187,  8.8114,  9.0544,  8.1888,  8.0256,
         8.2712,  7.8735,  8.3806,  8.3239,  8.1951,  8.1542,  8.8955,  8.1172,
         8.7627,  8.6084,  8.8146,  8.5941,  8.4780,  7.9555,  8.5277,  8.8061,
         8.1250,  8.5714,  8.6387,  7.6968,  8.5164,  8.5684,  8.8306,  8.1602,
         8.7625,  8.7649,  8.5770,  8.8186,  8.6728,  8.8203,  8.8378,  8.8105,
         8.2568,  8.4017,  9.9819,  9.0695,  8.9472,  8.4494,  7.6861,  8.1042,
         9.4347,  9.3720,  9.0644,  9.1978,  9.8322,  9.0001,  9.1845,  9.4331,
         9.3469, 11.0728,  9.3463,  8.5851,  9.6459,  9.1978,  9.2272,  9.5648,
         9.5100,  9.6435,  9.5191,  9.8178,  9.3789,  9.5861,  9.2071,  9.2581,
         8.5441,  9.6824,  9.0314,  9.2823, 10.2148, 10.1498,  9.3458,  8.9451,
         9.7831,  9.0849,  8.7979,  9.0224,  8.8580,  9.6999,  9.0158,  9.4426,
         9.2253,  9.1951,  9.4550,  9.1783,  9.5661,  9.3228,  9.4391,  9.2358,
         9.1685,  8.8517,  9.4883,  9.0652,  9.4498,  8.6077,  9.7002, 10.4473,
         9.9884,  8.8662,  9.4317,  9.2922,  9.0668,  9.7620,  9.2281,  9.4860,
         9.6106,  8.0309,  8.9221,  9.0221,  9.0459, 10.2337,  9.7973,  9.5885,
         9.0249,  8.8571,  8.7396,  8.9452,  9.2020,  9.1573,  8.4453,  9.3205,
         8.6279,  8.8441,  8.9208,  9.7410,  8.9751,  9.3891,  9.5010,  8.9050,
         8.8219,  8.4705,  9.4688,  9.2351,  9.1935,  9.7405,  9.1623,  8.1793,
         8.0767,  8.1733,  8.9422,  8.4693,  8.9346,  9.1120,  8.0441,  9.5878,
         9.5636,  8.8612,  9.0740,  9.1084,  9.7573,  9.8492,  9.6772,  9.1868,
         8.7703,  8.4915,  8.4426,  8.7710,  9.0574,  8.4157, 10.3115,  9.0996,
         8.5651,  9.0585,  8.4534,  8.7063,  8.4291,  8.3241,  7.9195,  9.0210,
         8.3222,  8.5985,  8.7874,  9.1164, 10.2389,  7.7741,  8.5940,  9.1308,
         9.3498,  8.7384,  8.3300,  8.2650,  8.7969,  8.6335,  8.6550,  8.7559,
         8.2821,  8.7692,  8.7830,  8.4424,  8.6879,  8.6025,  8.6327,  8.8367,
         9.4620,  8.5763,  8.3675,  8.4179,  9.2793,  8.8078,  9.3775,  9.6580,
        10.1902,  8.9006,  8.5452,  8.6059,  8.5685,  8.4081,  9.1445,  8.5781,
         8.9791,  8.7608,  8.6678,  8.4435,  7.6760,  8.6099,  8.8083,  8.1700,
         8.5081,  8.1777,  9.2411,  8.9585,  8.1853,  8.3657,  7.9898,  8.8000,
         8.1188,  9.3628,  8.9330,  7.7698,  9.6513,  9.2959,  9.1233,  9.0433,
         8.2871,  8.7241,  8.2236,  8.3967,  8.2571,  9.3786,  8.6354,  8.7345,
         8.3856,  8.4556,  8.7689,  8.7359,  8.6211,  9.7834,  8.9445,  8.8958,
         8.1290,  8.5490,  9.0263,  8.3258,  8.2379,  8.8249,  8.7301,  8.6340,
         9.3168,  8.7775,  9.9242,  8.9798,  9.1412,  8.5955,  8.1734,  8.9969,
         9.5123,  9.0581,  8.2497,  8.3555,  9.3501,  8.7719,  8.4376,  8.8456,
         8.2080,  8.9806,  8.5660,  9.1352,  8.5920,  8.2595,  8.1272,  9.0418,
         8.6972,  8.3413,  8.2742,  8.3118,  8.2167,  8.5550,  8.7187,  8.8749,
         9.7556,  8.4383,  9.0293,  8.1725,  8.5115,  8.9174,  8.9519,  9.0915],
       requires_grad=True), Parameter containing:
tensor([-1.6685e+00, -6.0155e-01, -5.9975e-01,  8.4297e-01,  8.5853e-01,
         5.6530e-02, -1.2840e+00, -5.1519e-01,  1.6774e+00,  3.2501e-01,
         1.4737e-01, -9.6427e-01,  2.1513e-01,  9.5219e-01, -3.7011e-03,
         6.6861e-01,  7.9758e-01,  2.4703e-01, -9.5743e-02,  1.9413e-01,
        -4.1348e-01, -8.3267e-01,  9.7684e-01, -5.1446e-01,  5.3158e-01,
         1.0447e+00,  1.7422e-01,  1.8719e+00,  7.0798e-01, -5.2600e-01,
         3.0636e-01,  3.1010e-01, -6.3830e-02, -2.3082e-01,  1.1787e+00,
        -2.5507e-01, -1.2747e+00,  7.3436e-01, -6.5267e-01,  1.0654e+00,
         7.2399e-01, -1.2560e+00, -6.7986e-01, -2.0358e-01, -2.1730e-01,
         5.1018e-02,  3.6179e-01,  2.0001e+00, -6.3287e-01,  1.5726e+00,
         2.8116e-01, -5.0017e-01, -1.6484e+00, -9.0159e-01, -2.5041e-01,
        -1.7400e-01,  6.4630e-01,  5.9313e-02, -7.2617e-03,  5.0565e-01,
         1.8716e+00, -8.8190e-01, -1.5941e-02,  7.8757e-02, -7.3102e-01,
        -4.5485e-01,  1.1036e+00, -3.2698e-01, -8.0969e-01, -6.6129e-01,
        -6.8337e-01, -1.6216e-01, -9.3829e-02,  6.4593e-01, -1.3784e+00,
         5.6243e-01,  8.1852e-01,  1.3817e-01,  5.7122e-01, -7.8534e-01,
        -9.2640e-01,  1.3659e-01, -6.8277e-01,  8.1809e-01, -1.4720e-01,
        -2.1538e+00, -7.1303e-02,  3.9166e-01, -7.9192e-01,  1.0671e+00,
         1.1110e+00,  9.8533e-01, -4.9213e-01, -8.4603e-01, -1.1119e+00,
         1.6191e+00,  7.9375e-02, -1.0472e-01, -5.4553e-01, -2.3597e-01,
        -2.6790e-01, -1.5157e+00, -2.6880e+00,  1.6904e-01,  2.3876e-01,
        -5.1432e-01,  5.7074e-01,  1.5021e+00, -1.7612e+00, -5.1162e-01,
         1.8071e+00, -2.2087e-01,  2.1651e-01,  3.1280e-01, -7.8104e-01,
        -2.3347e-01,  2.3287e+00,  4.3430e-01,  6.7748e-02, -7.1022e-01,
         1.3716e+00, -6.8236e-01, -1.9249e-02,  6.1708e-01,  3.5377e-01,
        -3.0060e-01,  8.7717e-01,  7.6281e-02,  1.6436e+00,  6.5745e-02,
         1.3911e+00, -1.1550e+00, -1.0942e+00, -5.4705e-02, -3.8439e-01,
        -2.0564e-01, -4.0284e-01,  1.8441e+00,  1.9942e+00, -3.3832e-01,
        -8.4892e-02,  2.6425e-01, -1.2417e-01, -8.9078e-01,  9.9491e-01,
        -1.2496e-01,  1.8860e-01, -1.9992e-01,  1.2828e+00, -1.6894e+00,
         1.7569e+00, -1.2428e-01, -6.2974e-01,  9.5339e-01,  5.5913e-01,
         8.3872e-01,  3.8710e-01,  4.7107e-01,  8.8813e-01,  1.5112e+00,
         6.4772e-02,  2.2407e+00, -2.4373e+00,  5.4596e-02, -2.3119e+00,
         7.8280e-01, -1.9582e+00, -4.4601e-01, -7.2071e-01,  1.0691e+00,
        -6.3960e-01, -9.6271e-01,  2.2167e+00,  1.6286e+00,  1.8287e-01,
        -1.0599e+00,  8.2727e-01,  4.2197e-01, -1.7488e-01,  2.2607e+00,
         1.6864e+00,  1.5625e+00, -2.4543e-01,  1.7482e-01, -1.4680e+00,
        -6.5810e-01, -1.7268e-01,  4.3401e-02,  1.2926e+00,  4.0332e-01,
         1.2770e-01, -5.4604e-02,  6.3163e-01,  5.8788e-01,  3.2761e-01,
         5.9546e-01, -1.4995e-02, -2.2789e-01, -3.0784e-01, -1.0060e-01,
        -1.6770e-01, -1.0096e+00,  9.2021e-01, -8.9897e-01, -5.9694e-01,
         8.2038e-01, -9.0749e-01, -3.0484e-01,  3.2038e-01,  1.2042e+00,
         6.0027e-01,  1.8709e-02, -4.0982e-01,  9.0638e-01, -9.6504e-01,
        -6.3824e-01, -2.3503e-02, -2.9762e-01,  1.1074e+00,  1.2170e-01,
         1.1205e+00, -1.9938e-01, -2.7814e-01, -3.8689e-01,  1.1914e+00,
        -6.5604e-01,  7.1130e-02, -7.0655e-01,  1.4939e+00, -2.6654e-01,
         4.9578e-01, -1.8316e+00, -6.2531e-01,  2.2550e+00, -9.1826e-01,
         2.1526e+00,  1.7631e-01,  1.2235e+00, -9.9429e-01, -8.9968e-01,
        -9.7487e-01, -3.5716e-01, -3.8364e-01, -2.2766e+00, -1.4803e+00,
         2.7549e-01, -5.8828e-01, -4.4274e-01,  2.0661e-02,  9.6894e-01,
        -5.4657e+00,  3.6806e+00, -5.8913e-01,  6.1390e-02,  9.8940e-01,
         1.8229e+00,  3.6467e-01,  2.9497e-01,  2.1930e+00,  1.8576e+00,
        -7.6800e-01,  1.3635e+00,  2.8457e-01,  2.9478e-02, -1.5696e+00,
         6.0662e-01, -1.1586e+00,  7.8294e-01,  3.4371e-01,  1.4571e-01,
        -4.5860e-01, -1.1644e+00, -1.2903e-01, -1.0055e+00, -5.4373e-02,
         1.3311e+00, -1.2074e+00,  8.7602e-02,  8.2454e-01, -2.2496e+00,
         2.4152e+00, -7.4065e-02,  3.5327e-01,  1.2092e+00,  6.9553e-02,
         2.4961e+00, -1.5597e+00,  4.1607e-01, -7.9795e-02, -4.4723e-01,
         2.6720e-01, -1.9072e+00, -6.5835e-01, -2.5336e-01, -9.1617e-01,
         8.8624e-01, -6.2251e-01,  9.6169e-01,  1.1279e+00, -5.6577e-01,
         1.8407e-01,  6.5294e-01, -6.1990e-01,  7.9014e-01, -6.0878e-01,
         1.0077e+00,  1.2790e+00, -1.3704e-02,  7.4945e-02,  5.6748e-01,
         1.0100e+00, -2.2963e-01, -9.2723e-01, -3.3553e-01, -7.0238e-01,
        -2.3026e+00, -5.3322e-02, -9.2703e-01,  1.4448e+00, -8.7800e-01,
        -6.4034e-01, -1.2203e+00, -1.1720e+00,  4.9662e-01,  3.4336e-01,
        -1.3538e+00,  4.1525e-01, -6.6715e-01,  4.1263e-01, -4.0352e-01,
        -3.7377e-01,  2.3441e+00,  3.5528e-01, -3.1402e-01,  3.5890e+00,
         2.8886e-02,  3.1700e-01, -7.7702e-01,  4.6834e-01,  5.4264e-01,
        -1.0964e+00,  1.4711e+00,  9.3168e-01,  5.4778e-01, -7.4466e-01,
         7.7792e-01,  1.5176e+00,  1.6450e+00,  2.6295e-02, -1.8510e+00,
         2.2687e-01,  1.3993e-01,  1.1727e+00,  6.4835e-02,  1.9505e-01,
         2.2950e-01, -1.3806e+00,  7.7071e-02, -1.8424e+00, -9.5833e-01,
        -8.7708e-01,  9.1619e-01,  1.0074e+00,  8.0151e-03,  1.0098e+00,
        -3.9247e-02, -2.7759e-01, -2.1021e+00, -4.1539e-01, -1.5258e-01,
         3.3655e-01, -2.6506e-01,  2.1964e+00,  6.0517e-01,  5.7097e-01,
         7.5984e-02,  1.0848e+00,  4.8223e-01,  8.0175e-01, -9.1310e-01,
         6.3781e-01,  1.1286e-01,  1.3899e+00, -4.5585e-01, -8.9240e-01,
        -5.6478e-01, -1.0510e+00, -6.3237e-01,  7.5205e-01, -5.0555e-01,
        -4.2338e-01,  1.1653e+00, -4.3769e-01, -4.9660e-01,  8.4734e-01,
         3.1255e-01,  1.4222e+00,  5.1850e-01,  5.9261e-03,  6.8774e-01,
        -2.2485e+00, -2.1259e-01,  1.7378e-01, -3.9461e+00,  8.5505e-01,
        -1.4455e+00,  2.2031e-02, -8.7173e-01,  9.4395e-01,  1.3690e+00,
         9.2501e-01,  5.9211e-01,  5.9655e-01, -9.7749e-01,  5.1079e-01,
         1.7735e-02,  3.1332e-01,  2.8223e-01,  2.2100e-01,  9.7640e-01,
         7.5128e-01, -1.2068e+00,  8.0254e-01,  4.7232e-01, -5.7225e-01,
         3.0082e-01, -4.5279e-01, -3.4367e-01, -2.8903e-01,  1.1790e+00,
        -2.3224e+00,  7.0363e-01,  4.5137e-01,  1.5505e+00,  8.4144e-01,
         3.9210e-02, -9.5217e-01, -9.1495e-01, -3.6971e-01,  1.3037e-01,
         1.0739e+00, -5.2155e-02, -1.7844e+00, -6.9291e-01,  6.2565e-01,
        -1.6121e+00, -4.0668e-01,  6.9844e-01,  2.1026e-01, -3.4400e-01,
        -2.3706e-02, -4.4798e-01,  6.0481e-01,  7.8424e-01,  6.2746e-01,
        -7.7199e-01,  2.0300e-01,  9.1969e-01, -1.1502e+00, -3.1036e-01,
         3.8410e-01,  3.3024e+00,  9.6322e-02,  3.5212e-01,  1.4104e+00,
        -2.7992e-01,  4.1524e-01, -1.1456e+00, -2.6424e-01, -6.5836e-02,
        -5.0440e-01,  5.7824e-01, -7.8925e-01, -2.0960e+00, -1.2973e-01,
         1.0862e+00,  1.3762e+00, -3.2528e-02, -2.2924e+00, -8.9146e-01,
        -3.0597e+00,  6.0693e-01, -2.5389e-01, -2.9927e-01,  3.3115e-01,
        -4.1729e-01,  1.3418e+00,  8.3576e-01, -1.0882e+00,  1.0617e+00,
        -2.8175e-01,  1.1439e+00, -4.9022e-01, -1.1799e-01, -4.8219e-01,
         9.3034e-02,  1.2776e+00, -1.2725e-01,  5.8007e-01,  1.3756e+00,
         1.2398e-01, -3.1594e-01, -7.3134e-02,  4.6101e-01,  1.4797e-01,
        -8.3583e-01, -1.8117e+00,  1.3540e-01,  1.4121e-01,  5.1246e-01,
         1.6791e-01, -1.5676e+00], requires_grad=True)]

Looking at how neural networks really work

partial in Python (from the functools module) is an idea found in a lot of languages. It takes an existing function, fills in some of its arguments in advance, and returns a new function in which you only substitute the remaining value, such as x.
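A minimal sketch of the idea. The names quad and make_quad are hypothetical: they are modelled on how the lesson builds quadratics, not copied from it. partial pre-fills the coefficients a, b and c, so the returned function only needs x.

```python
from functools import partial

def quad(a, b, c, x):
    """General quadratic: a*x**2 + b*x + c."""
    return a * x**2 + b * x + c

def make_quad(a, b, c):
    """Fix the coefficients and return a function that only needs x."""
    return partial(quad, a, b, c)

f = make_quad(3, 2, 1)   # f(x) is the same as quad(3, 2, 1, x)
print(f(1.5))            # 3*2.25 + 2*1.5 + 1 = 10.75
```

This is handy when plotting or fitting, because the rest of the code can treat every quadratic as a plain one-argument function.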

Jeremy explained this today with this notebook.

I followed along with the notebook, making slight changes to variable and function names and changing some of the defined values, to experiment with my own version of the How does a neural network really work notebook.

Note

Thanks to Alex Strick for sharing this trick for working with notebooks during the Delft-fastai sessions.

Important

Usually, when I explain that this is how neural networks work, one of my students said it is like drawing an owl from its outline. In deep learning …, there is just step 1, where we draw the outline; the computer automatically does step 2 and draws a beautiful owl.
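The owl analogy hinges on step 2: the computer adjusting parameters until the function fits the data. Below is a small sketch of that loop under my own assumptions. The data is a made-up noisy quadratic, the loss is mean squared error, and the gradients are derived by hand rather than computed by PyTorch's autograd as in the lesson notebook. Starting from a poor guess, it repeatedly nudges a, b and c against the gradient.

```python
import numpy as np

def quad(a, b, c, x):
    return a * x**2 + b * x + c

# Hypothetical data: a known quadratic (a=3, b=2, c=1) plus a little noise.
rng = np.random.default_rng(42)
x = np.linspace(-2, 2, 20)
y = quad(3.0, 2.0, 1.0, x) + rng.normal(0, 0.1, size=x.shape)

def mse(params):
    """Mean squared error of the quadratic given by params against y."""
    a, b, c = params
    return np.mean((quad(a, b, c, x) - y) ** 2)

def mse_grad(params):
    """Hand-derived gradient of mse with respect to (a, b, c)."""
    a, b, c = params
    err = quad(a, b, c, x) - y
    return np.array([np.mean(2 * err * x**2),
                     np.mean(2 * err * x),
                     np.mean(2 * err)])

params = np.array([1.0, 1.0, 1.0])   # arbitrary starting guess
start_loss = mse(params)
lr = 0.01                            # learning rate
for step in range(2000):
    params -= lr * mse_grad(params)  # take a small step downhill

print(params.round(2), mse(params))  # params end up close to [3, 2, 1]
```

Nobody tells the loop what a, b and c should be. It only ever sees the loss and which direction makes it smaller.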

Intuitive understanding of neural networks

Using ReLUs, we can tweak our function so that it fits the data. A neural network built from a bunch of ReLU functions can be optimized to fit any squiggly line or complex shape, which need not be a quadratic.
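To make this concrete, here is a small sketch. The names relu_line and sum_of_relus are hypothetical. Each ReLU is a straight line with its negative part clipped to zero, and adding a few of them together gives a bent shape that is not a quadratic. Training would adjust each (m, b) pair; here they are fixed by hand.

```python
import numpy as np

def relu_line(m, b, x):
    """One ReLU: the line m*x + b with negative values clipped to zero."""
    return np.clip(m * x + b, 0.0, None)

def sum_of_relus(params, x):
    """Add several ReLUs together; params is a list of (m, b) pairs."""
    return sum(relu_line(m, b, x) for m, b in params)

x = np.linspace(-2, 2, 5)           # [-2, -1, 0, 1, 2]
params = [(1.0, 1.0), (-2.0, 0.0)]  # two hand-picked ReLUs
print(sum_of_relus(params, x))      # [4. 2. 1. 2. 3.]
```

Each ReLU adds one bend. With more of them, the bends can go wherever the data needs them.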

Important

In linear algebra, almost all you ever need is matrix multiplication. At school, linear algebra is taught as if you need tons of experience with it before you can do machine learning. Yet matrix multiplication is exactly the operation GPUs are so good at, and some GPUs even have dedicated tensor cores for it.
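A minimal NumPy sketch of why matrix multiplication is enough. A single @ computes every dot product between the data rows and the weight columns, and gives the same result as an explicit double loop. Clipping the result at zero then turns it into one layer of ReLUs. The shapes and random values are arbitrary stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))   # 4 rows of data, 3 features each (random stand-in)
W = rng.normal(size=(3, 2))   # weights for 2 hidden units

# One matrix multiplication computes every row-by-column dot product at once...
H = X @ W

# ...which matches this double loop of sums of products.
H_loop = np.array([[sum(X[i, k] * W[k, j] for k in range(3))
                    for j in range(2)]
                   for i in range(4)])

# Applying ReLU to the result gives one layer of a neural network.
layer_out = np.maximum(H, 0)
print(H.shape, np.allclose(H, H_loop))  # (4, 2) True
```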

Refresher on matrix multiplication

Using the Titanic dataset, we predicted who survived and who didn't in Excel, as a way to understand how neural networks work. This part starts in the video at [1:05:00].
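Here is a rough Python translation of the spreadsheet idea, built on my own assumptions about its layout. It uses a few numeric feature columns and two sets of parameters. Each row's sum of products with a set of parameters goes through a ReLU, the two results are added, and a squared-error loss compares that against the Survived column. The rows below are made up and are not real Titanic passengers. Plain gradient-descent steps stand in for however the spreadsheet itself was optimized.

```python
import numpy as np

# Hypothetical toy rows, NOT real Titanic data: [constant 1, is_female, scaled age]
X = np.array([[1.0, 1.0, 0.2],
              [1.0, 0.0, 0.5],
              [1.0, 1.0, 0.8],
              [1.0, 0.0, 0.1]])
survived = np.array([1.0, 0.0, 1.0, 0.0])

def predict(X, W):
    # Each column of W is one set of parameters (a "sumproduct" with each row);
    # negatives are clipped to zero (ReLU) and the columns are added together.
    return np.maximum(X @ W, 0).sum(axis=1)

def loss(X, W, y):
    # Mean squared error against the 0/1 survived column
    return np.mean((predict(X, W) - y) ** 2)

def loss_grad(X, W, y):
    # Gradient of the loss; only ReLUs that are "on" (z > 0) pass gradient back
    z = X @ W
    dpred = 2 * (np.maximum(z, 0).sum(axis=1) - y) / len(y)
    return X.T @ (dpred[:, None] * (z > 0))

W = np.full((3, 2), 0.1)   # two columns of parameters, small positive start
start = loss(X, W, survived)
for _ in range(1000):
    W -= 0.1 * loss_grad(X, W, survived)

print(start, loss(X, W, survived))  # the loss goes down
```

Replacing the per-column sums of products with a single X @ W is the same step the matrix multiplication refresher is about.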

Next week, we are going to look into validation sets and metrics in more depth. We will also be looking at a Kaggle notebook on getting started with NLP.