Fix database locked errors (#3153)

* Make read-only operations use WithReadTxn
* Allow one database write thread
* Add unit test for concurrent transactions
* Perform some actions after commit to release txn
* Suppress some errors from cancelled context
This commit is contained in:
WithoutPants
2022-11-21 06:49:10 +11:00
committed by GitHub
parent 420c6fa9d7
commit f39fa416a9
54 changed files with 626 additions and 311 deletions

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/fvbommel/sortorder"
@@ -73,7 +72,7 @@ type Database struct {
schemaVersion uint
writeMu sync.Mutex
lockChan chan struct{}
}
func NewDatabase() *Database {
@@ -87,6 +86,7 @@ func NewDatabase() *Database {
Image: NewImageStore(fileStore),
Gallery: NewGalleryStore(fileStore, folderStore),
Performer: NewPerformerStore(),
lockChan: make(chan struct{}, 1),
}
return ret
@@ -106,8 +106,8 @@ func (db *Database) Ready() error {
// necessary migrations must be run separately using RunMigrations.
// Returns true if the database is new.
func (db *Database) Open(dbPath string) error {
db.writeMu.Lock()
defer db.writeMu.Unlock()
db.lockNoCtx()
defer db.unlock()
db.dbPath = dbPath
@@ -152,9 +152,36 @@ func (db *Database) Open(dbPath string) error {
return nil
}
// lock locks the database for writing.
// This method will block until the lock is acquired of the context is cancelled.
func (db *Database) lock(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case db.lockChan <- struct{}{}:
return nil
}
}
// lock locks the database for writing. This method will block until the lock is acquired.
func (db *Database) lockNoCtx() {
db.lockChan <- struct{}{}
}
// unlock unlocks the database
func (db *Database) unlock() {
// will block the caller if the lock is not held, so check first
select {
case <-db.lockChan:
return
default:
panic("database is not locked")
}
}
func (db *Database) Close() error {
db.writeMu.Lock()
defer db.writeMu.Unlock()
db.lockNoCtx()
defer db.unlock()
if db.db != nil {
if err := db.db.Close(); err != nil {

View File

@@ -929,18 +929,15 @@ func Test_imageQueryBuilder_Destroy(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
withRollbackTxn(func(ctx context.Context) error {
if err := qb.Destroy(ctx, tt.id); (err != nil) != tt.wantErr {
t.Errorf("imageQueryBuilder.Destroy() error = %v, wantErr %v", err, tt.wantErr)
}
if err := qb.Destroy(ctx, tt.id); (err != nil) != tt.wantErr {
t.Errorf("imageQueryBuilder.Destroy() error = %v, wantErr %v", err, tt.wantErr)
}
// ensure cannot be found
i, err := qb.Find(ctx, tt.id)
// ensure cannot be found
i, err := qb.Find(ctx, tt.id)
assert.NotNil(err)
assert.Nil(i)
return nil
})
assert.NotNil(err)
assert.Nil(i)
})
}
}

View File

@@ -1398,18 +1398,15 @@ func Test_sceneQueryBuilder_Destroy(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
withRollbackTxn(func(ctx context.Context) error {
if err := qb.Destroy(ctx, tt.id); (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.Destroy() error = %v, wantErr %v", err, tt.wantErr)
}
if err := qb.Destroy(ctx, tt.id); (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.Destroy() error = %v, wantErr %v", err, tt.wantErr)
}
// ensure cannot be found
i, err := qb.Find(ctx, tt.id)
// ensure cannot be found
i, err := qb.Find(ctx, tt.id)
assert.NotNil(err)
assert.Nil(i)
return nil
})
assert.NotNil(err)
assert.Nil(i)
})
}
}
@@ -1477,26 +1474,23 @@ func Test_sceneQueryBuilder_Find(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
withTxn(func(ctx context.Context) error {
got, err := qb.Find(ctx, tt.id)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.Find() error = %v, wantErr %v", err, tt.wantErr)
return nil
got, err := qb.Find(ctx, tt.id)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
// load relationships
if err := loadSceneRelationships(ctx, *tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return
}
if got != nil {
// load relationships
if err := loadSceneRelationships(ctx, *tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return nil
}
clearSceneFileIDs(got)
}
clearSceneFileIDs(got)
}
assert.Equal(tt.want, got)
return nil
})
assert.Equal(tt.want, got)
})
}
}
@@ -1620,23 +1614,19 @@ func Test_sceneQueryBuilder_FindByChecksum(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
withTxn(func(ctx context.Context) error {
assert := assert.New(t)
got, err := qb.FindByChecksum(ctx, tt.checksum)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByChecksum() error = %v, wantErr %v", err, tt.wantErr)
return nil
}
assert := assert.New(t)
got, err := qb.FindByChecksum(ctx, tt.checksum)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByChecksum() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return nil
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return
}
assert.Equal(tt.want, got)
return nil
})
assert.Equal(tt.want, got)
})
}
}
@@ -1694,23 +1684,20 @@ func Test_sceneQueryBuilder_FindByOSHash(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
withTxn(func(ctx context.Context) error {
got, err := qb.FindByOSHash(ctx, tt.oshash)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByOSHash() error = %v, wantErr %v", err, tt.wantErr)
return nil
}
got, err := qb.FindByOSHash(ctx, tt.oshash)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByOSHash() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return nil
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("sceneQueryBuilder.FindByOSHash() = %v, want %v", got, tt.want)
}
return nil
})
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("sceneQueryBuilder.FindByOSHash() = %v, want %v", got, tt.want)
}
})
}
}
@@ -1768,23 +1755,19 @@ func Test_sceneQueryBuilder_FindByPath(t *testing.T) {
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
withTxn(func(ctx context.Context) error {
assert := assert.New(t)
got, err := qb.FindByPath(ctx, tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByPath() error = %v, wantErr %v", err, tt.wantErr)
return nil
}
assert := assert.New(t)
got, err := qb.FindByPath(ctx, tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("sceneQueryBuilder.FindByPath() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return nil
}
if err := postFindScenes(ctx, tt.want, got); err != nil {
t.Errorf("loadSceneRelationships() error = %v", err)
return
}
assert.Equal(tt.want, got)
return nil
})
assert.Equal(tt.want, got)
})
}
}

View File

@@ -17,6 +17,7 @@ type key int
const (
txnKey key = iota + 1
dbKey
exclusiveKey
)
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
@@ -28,7 +29,7 @@ func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, dbKey, db.db), nil
}
func (db *Database) Begin(ctx context.Context) (context.Context, error) {
func (db *Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
if tx, _ := getTx(ctx); tx != nil {
// log the stack trace so we can see
logger.Error(string(debug.Stack()))
@@ -36,11 +37,23 @@ func (db *Database) Begin(ctx context.Context) (context.Context, error) {
return nil, fmt.Errorf("already in transaction")
}
if exclusive {
if err := db.lock(ctx); err != nil {
return nil, err
}
}
tx, err := db.db.BeginTxx(ctx, nil)
if err != nil {
// begin failed, unlock
if exclusive {
db.unlock()
}
return nil, fmt.Errorf("beginning transaction: %w", err)
}
ctx = context.WithValue(ctx, exclusiveKey, exclusive)
return context.WithValue(ctx, txnKey, tx), nil
}
@@ -50,6 +63,8 @@ func (db *Database) Commit(ctx context.Context) error {
return err
}
defer db.txnComplete(ctx)
if err := tx.Commit(); err != nil {
return err
}
@@ -63,6 +78,8 @@ func (db *Database) Rollback(ctx context.Context) error {
return err
}
defer db.txnComplete(ctx)
if err := tx.Rollback(); err != nil {
return err
}
@@ -70,6 +87,12 @@ func (db *Database) Rollback(ctx context.Context) error {
return nil
}
func (db *Database) txnComplete(ctx context.Context) {
if exclusive := ctx.Value(exclusiveKey).(bool); exclusive {
db.unlock()
}
}
func getTx(ctx context.Context) (*sqlx.Tx, error) {
tx, ok := ctx.Value(txnKey).(*sqlx.Tx)
if !ok || tx == nil {

View File

@@ -0,0 +1,243 @@
//go:build integration
// +build integration
package sqlite_test
import (
"context"
"sync"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// this test is left commented out as it is not deterministic.
// func TestConcurrentExclusiveTxn(t *testing.T) {
// const (
// workers = 8
// loops = 100
// innerLoops = 10
// sleepTime = 2 * time.Millisecond
// )
// ctx := context.Background()
// var wg sync.WaitGroup
// for k := 0; k < workers; k++ {
// wg.Add(1)
// go func(wk int) {
// for l := 0; l < loops; l++ {
// // change this to WithReadTxn to see locked database error
// if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
// for ll := 0; ll < innerLoops; ll++ {
// scene := &models.Scene{
// Title: "test",
// }
// if err := db.Scene.Create(ctx, scene, nil); err != nil {
// return err
// }
// if err := db.Scene.Destroy(ctx, scene.ID); err != nil {
// return err
// }
// }
// time.Sleep(sleepTime)
// return nil
// }); err != nil {
// t.Errorf("worker %d loop %d: %v", wk, l, err)
// }
// }
// wg.Done()
// }(k)
// }
// wg.Wait()
// }
func TestConcurrentReadTxn(t *testing.T) {
var wg sync.WaitGroup
ctx := context.Background()
c := make(chan struct{}, 1)
// first thread
wg.Add(2)
go func() {
defer wg.Done()
if err := txn.WithReadTxn(ctx, db, func(ctx context.Context) error {
scene := &models.Scene{
Title: "test",
}
if err := db.Scene.Create(ctx, scene, nil); err != nil {
return err
}
// wait for other thread to start
c <- struct{}{}
<-c
if err := db.Scene.Destroy(ctx, scene.ID); err != nil {
return err
}
return nil
}); err != nil {
t.Errorf("unexpected error in first thread: %v", err)
}
}()
// second thread
go func() {
defer wg.Done()
_ = txn.WithReadTxn(ctx, db, func(ctx context.Context) error {
// wait for first thread
<-c
defer func() {
c <- struct{}{}
}()
scene := &models.Scene{
Title: "test",
}
// expect error when we try to do this, as the other thread has already
// modified this table
if err := db.Scene.Create(ctx, scene, nil); err != nil {
if !db.IsLocked(err) {
t.Errorf("unexpected error: %v", err)
}
return err
} else {
t.Errorf("expected locked error in second thread")
}
return nil
})
}()
wg.Wait()
}
func TestConcurrentExclusiveAndReadTxn(t *testing.T) {
var wg sync.WaitGroup
ctx := context.Background()
c := make(chan struct{}, 1)
// first thread
wg.Add(2)
go func() {
defer wg.Done()
if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
scene := &models.Scene{
Title: "test",
}
if err := db.Scene.Create(ctx, scene, nil); err != nil {
return err
}
// wait for other thread to start
c <- struct{}{}
<-c
if err := db.Scene.Destroy(ctx, scene.ID); err != nil {
return err
}
return nil
}); err != nil {
t.Errorf("unexpected error in first thread: %v", err)
}
}()
// second thread
go func() {
defer wg.Done()
_ = txn.WithReadTxn(ctx, db, func(ctx context.Context) error {
// wait for first thread
<-c
defer func() {
c <- struct{}{}
}()
if _, err := db.Scene.Find(ctx, sceneIDs[sceneIdx1WithPerformer]); err != nil {
t.Errorf("unexpected error: %v", err)
return err
}
return nil
})
}()
wg.Wait()
}
// this test is left commented out as it is not deterministic.
// func TestConcurrentExclusiveAndReadTxns(t *testing.T) {
// const (
// writeWorkers = 4
// readWorkers = 4
// loops = 200
// innerLoops = 10
// sleepTime = 1 * time.Millisecond
// )
// ctx := context.Background()
// var wg sync.WaitGroup
// for k := 0; k < writeWorkers; k++ {
// wg.Add(1)
// go func(wk int) {
// for l := 0; l < loops; l++ {
// if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
// for ll := 0; ll < innerLoops; ll++ {
// scene := &models.Scene{
// Title: "test",
// }
// if err := db.Scene.Create(ctx, scene, nil); err != nil {
// return err
// }
// if err := db.Scene.Destroy(ctx, scene.ID); err != nil {
// return err
// }
// }
// time.Sleep(sleepTime)
// return nil
// }); err != nil {
// t.Errorf("write worker %d loop %d: %v", wk, l, err)
// }
// }
// wg.Done()
// }(k)
// }
// for k := 0; k < readWorkers; k++ {
// wg.Add(1)
// go func(wk int) {
// for l := 0; l < loops; l++ {
// if err := txn.WithReadTxn(ctx, db, func(ctx context.Context) error {
// for ll := 0; ll < innerLoops; ll++ {
// if _, err := db.Scene.Find(ctx, sceneIDs[ll%totalScenes]); err != nil {
// return err
// }
// }
// time.Sleep(sleepTime)
// return nil
// }); err != nil {
// t.Errorf("read worker %d loop %d: %v", wk, l, err)
// }
// }
// wg.Done()
// }(k)
// }
// wg.Wait()
// }