Fix some deprecated types

This commit is contained in:
a666
2023-08-25 01:58:19 -06:00
parent 84d41e49b3
commit b6c1a1bbbf
7 changed files with 34 additions and 38 deletions

View File

@@ -15,7 +15,7 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, List
from typing import Optional, NamedTuple
def narrow_trunc(
@@ -97,7 +97,7 @@ def _query_chunk_attention(
)
return summarize_chunk(query, key_chunk, value_chunk)
chunks: List[AttnChunk] = [
chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))