Spaces:
Configuration error
Configuration error
| import itertools | |
| from typing import Optional | |
class _SimpleLRUDict(dict):
    """Minimal bounded LRU mapping used when ``cachetools`` is unavailable.

    Relies on ``dict`` preserving insertion order: the first key is the least
    recently used one. Reading with ``[]`` and writing both move a key to the
    end; ``in`` does not change recency (mirroring ``cachetools.LRUCache``).
    """

    def __init__(self, maxsize: int):
        super().__init__()
        self.maxsize = maxsize

    def __getitem__(self, key):
        # Pop and re-insert to mark the key as most recently used.
        # Raises KeyError if the key is absent.
        value = super().pop(key)
        super().__setitem__(key, value)
        return value

    def __setitem__(self, key, value):
        super().pop(key, None)
        super().__setitem__(key, value)
        # Evict least-recently-used entries until back within capacity.
        while self and len(self) > self.maxsize:
            super().pop(next(iter(self)))


class TaggedCache:
    """Key/value cache whose entries are partitioned into per-tag LRU stores.

    Each stored value is a ``(tag, payload)`` tuple. The tag selects which
    underlying store holds the entry, and each store has its own capacity.
    A key lives in at most one store at a time; re-storing it under another
    tag moves it.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        """
        Args:
            tag_settings: optional mapping of tag -> maximum number of entries
                for that tag's store. Tags not listed use a built-in default.
        """
        # Keep the caller's dict even when it is empty, so later additions to
        # it apply to tags whose store has not been created yet.
        # (``tag_settings or {}`` would silently swap an empty dict for a new one.)
        self._tag_settings = tag_settings if tag_settings is not None else {}
        self._data = {}  # tag -> per-tag store (cachetools LRUCache or _SimpleLRUDict)

    @staticmethod
    def _default_size(tag):
        """Return the built-in capacity for ``tag`` when no setting overrides it."""
        if 'ckpt' in tag:
            return 5
        if tag in ('latent', 'image'):
            return 100
        return 20

    def _new_store(self, tag):
        """Create the bounded LRU store for a tag that has not been seen before."""
        maxsize = self._tag_settings.get(tag, self._default_size(tag))
        try:
            from cachetools import LRUCache
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            return _SimpleLRUDict(maxsize)
        return LRUCache(maxsize=maxsize)

    def _find_store(self, key):
        """Return the per-tag store currently holding ``key``, or None."""
        for store in self._data.values():
            if key in store:
                return store
        return None

    def __getitem__(self, key):
        store = self._find_store(key)
        if store is None:
            raise KeyError(f'Key `{key}` does not exist')
        return store[key]

    def __setitem__(self, key, value: tuple):
        # value: (tag: str, (islist: bool, data: *))
        # A key may change tags, so drop any existing entry first.
        old_store = self._find_store(key)
        if old_store is not None:
            old_store.pop(key, None)
        tag = value[0]
        if tag not in self._data:
            self._data[tag] = self._new_store(tag)
        self._data[tag][key] = value

    def __delitem__(self, key):
        store = self._find_store(key)
        if store is None:
            raise KeyError(f'Key `{key}` does not exist')
        del store[key]

    def __contains__(self, key):
        return self._find_store(key) is not None

    def items(self):
        """Yield ``(key, value)`` pairs from every tag store, one tag after another."""
        # Snapshot the store list so tags added during iteration don't break it.
        stores = list(self._data.values())
        for store in stores:
            yield from store.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        store = self._find_store(key)
        return default if store is None else store[key]

    def clear(self):
        """Drop every tag store (and so every cached entry)."""
        self._data = {}
# Per-tag size overrides handed to TaggedCache (tag -> max entries for that tag).
cache_settings = dict()
# Per-key store counter maintained by update_cache().
cache_count = dict()
# Module-wide cache instance; remove_cache('*') replaces it with a fresh one.
cache = TaggedCache(cache_settings)
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` in the module cache, tagged with ``tag``.

    Also maintains ``cache_count[k]``: the first store of a key records 0,
    and each later store of the same key increments it by 1.
    """
    # ``cache`` is read at call time, so this writes to whichever instance
    # the module currently holds (remove_cache('*') may have replaced it).
    cache[k] = (tag, v)
    previous = cache_count.get(k)
    cache_count[k] = 0 if previous is None else previous + 1
def remove_cache(key):
    """Remove ``key`` from the module cache, or reset the whole cache for ``'*'``.

    An unknown key is reported on stdout instead of raising.
    """
    global cache
    if key == '*':
        # Rebinds the module global to a brand-new instance.
        # NOTE(review): any module that did ``from ... import cache`` keeps the
        # old instance; confirm no caller holds such a reference.
        cache = TaggedCache(cache_settings)
        return
    if key not in cache:
        print(f"invalid {key}")
        return
    del cache[key]