diff --git a/internal/cache/contentcache.go b/internal/cache/contentcache.go index fcd0b2c..0af7850 100644 --- a/internal/cache/contentcache.go +++ b/internal/cache/contentcache.go @@ -1,7 +1,7 @@ package cache import ( - "io" + "bytes" "sync" "github.com/dgraph-io/ristretto" @@ -11,13 +11,13 @@ const ( contentCacheSize int64 = 1000000000 ) -type ContentCache[T io.WriteCloser] struct { - m map[string]sync.Mutex +type ContentCache struct { + m map[string]*sync.Mutex cache *ristretto.Cache - load func(id string) (T, error) + load func(id string) (*bytes.Buffer, error) } -func NewContentCache[T io.WriteCloser](load func(id string) (T, error)) (*ContentCache[T], error) { +func NewContentCache(load func(id string) (*bytes.Buffer, error)) (*ContentCache, error) { cache, err := ristretto.NewCache(&ristretto.Config{ NumCounters: 1e7, MaxCost: contentCacheSize, @@ -27,5 +27,24 @@ func NewContentCache[T io.WriteCloser](load func(id string) (T, error)) (*Conten return nil, err } - return &ContentCache[T]{cache: cache, load: load}, nil + return &ContentCache{m: make(map[string]*sync.Mutex), cache: cache, load: load}, nil +} + +func (c *ContentCache) getContent(id string) (*bytes.Buffer, error) { + c.m[id].Lock() + defer c.m[id].Unlock() + + temp, found := c.cache.Get(id) + ce := temp.(*bytes.Buffer) + if found { + return ce, nil + } + + ce, err := c.load(id) + if err != nil { + return bytes.NewBuffer([]byte{}), err + } + + c.cache.Set(id, ce, int64(ce.Len())) + return ce, nil }