Skip to content

Refactor LLourney to be more self-contained.#1

Open
tmmnmmn wants to merge 3 commits into
veezbo:mainfrom
tmmnmmn:add-training-script
Open

Refactor LLourney to be more self-contained.#1
tmmnmmn wants to merge 3 commits into
veezbo:mainfrom
tmmnmmn:add-training-script

Conversation

@tmmnmmn

@tmmnmmn tmmnmmn commented Aug 20, 2023

Copy link
Copy Markdown

Refactor the LLourney class to be configured through arguments to __init__ rather than constants hardcoded in model.py.

Comment thread model.py
denoised_latent_img = model(img, [input_str]*BATCH_SIZE, [1]*BATCH_SIZE)
# print(denoised_latent_img.shape)
@torch.no_grad()
def forward_pass_test(model: LLourney, input_str: str, batch_size: int = 2, device="cpu"):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for getting this working again

Comment thread model.py Outdated
# )
test_model = LLourney(
image_dim=64,
# vae_model_id_or_path="test_vae",

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need the "test_vae" you defined above around for the test? Or is it the default 'runwayml/stable-diffusion-v1-5' small enough?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the current LLourney parameters for the test are meant to work for a small VAE; I believe if you use a 512 x 512 image and the VAE from runwayml/stable-diffusion-v1-5, the patch embedding (with patch_dim=2) will produce way more tokens than the hf-internal-testing/tiny-random-gpt2 testing checkpoint's context window can handle. So it's probably best to stick with test_vae, and it makes the test faster and less resource-intensive to run since it's smaller. (I guess it could be worth uploading it to the hub so we don't have to keep around the code to reproduce it.)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to use it for the test, but then should we uncomment out the code that creates/saves it? Maybe with a check to see whether it has already been saved.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a simple check to see if the test_vae directory exists and if not generate the random small VAE checkpoint for testing.

Comment thread model.py
@torch.no_grad()
def forward_pass_test(model: LLourney, input_str: str, batch_size: int = 2, device="cpu"):
# Prepare random image in pixel space
img = torch.randn((batch_size, 3, model.image_dim, model.image_dim), device=device)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to make this a FloatTensor, could add .double() here to make the types line up

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand the comment, torch.randn should use dtype = torch.float (= torch.float32) by default. So it's already a torch.FloatTensor, and I'm not sure casting it to a double (= torch.float64) would make sense here.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically Pycharm complains at the typing (it doesn't believe img is a FloatTensor) which could be its own fault. But adding .double() suppresses that error. Not a big deal in any case.

Comment thread model.py Outdated
from typing import Optional

# Input image dimension
IMG_DIM = 128

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you delete the constants no longer needed here?

Comment thread model.py
Comment thread model.py
# projected_img_emb = projected_img_emb.reshape(
# shape=(B, self.latent_dim // self.patch_dim, self.latent_dim // self.patch_dim, self.patch_dim, self.patch_dim, self.vae_latent_channels)
# )
denoised_latent_image = rearrange(

@veezbo veezbo Aug 26, 2023

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you adjust the shaping comments accordingly? Perhaps explain a bit more about what's going on here, though it's likely a lot clearer than before.

Comment thread model.py
Comment thread model.py Outdated
else:
latent_image = self.vae.encode(img).latent_dist.mean
if scale_latents:
latent_image = latent_image * self.vae.config.scaling_factor

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, it would be better to use self.vae_scale_factor instead of self.vae.config.scaling_factor

@tmmnmmn tmmnmmn Aug 28, 2023

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.vae_scale_factor and self.vae.config.scaling_factor are two different quantities: self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) (= 8) is the factor by which the VAE downsamples the spatial resolution of the input image, whereas self.vae.config.scaling_factor (= 0.1825) is a scalar which scales the output of the VAE encoder. So I believe using self.vae.config.scaling_factor is correct here.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realize self.vae_scale_factor and self.vae.config.scaling_factor were different. Can you rename them so their respective usages are more clear?

Comment thread model.py Outdated
@torch.no_grad()
def decode_image_latents(self, denoised_latent_image: Tensor, as_numpy: bool = True) -> np.ndarray:
def decode_image_latents(self, denoised_latent_image: torch.FloatTensor, as_numpy: bool = True) -> np.ndarray:
denoised_latent_image = 1 / self.vae.config.scaling_factor * denoised_latent_image

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, it would be better to use self.vae_scale_factor instead of self.vae.config.scaling_factor

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1 (comment). I believe using self.vae.config.scaling_factor is correct here.

@veezbo veezbo left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good functionally, only a few minor comments

Comment thread model.py Outdated

# Prepare input tokens and attention mask (for padding)
input_str_list = [input_str] * batch_size
input_ids, attention_mask = model.tokenize(input_str_list)

@veezbo veezbo Aug 26, 2023

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it's used as a mask of the attention, but I think calling it pad_mask here is more in-line with the output from a tokenizer. Feel free to override though.

Comment thread model.py
Comment thread model.py
Comment thread model.py

@veezbo veezbo left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one renaming request and maybe a couple minor suggestions. looks good

@veezbo

veezbo commented Sep 16, 2023

Copy link
Copy Markdown
Owner

@dg845 feel free to merge when you're ready

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants