diff --git a/main.go b/main.go index 1cea043..266369b 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,12 @@ package main import ( + "database/sql" "fmt" "log" "os" "os/signal" "syscall" - "database/sql" - "github.com/BurntSushi/toml" "github.com/bwmarrin/discordgo" @@ -15,9 +14,9 @@ import ( ) var ( - config Config - client *discordgo.Session - db *sql.DB + config Config + client *discordgo.Session + db *sql.DB ) type Config struct { @@ -38,16 +37,13 @@ type Database struct { Password string `toml:"password"` } - func loadConfig(filename string) (Config, error) { var config Config _, err := toml.DecodeFile(filename, &config) return config, err } -func connectDb(config Config) { - var err error - +func connectDb(config Config) (*sql.DB, error) { connectionString := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", config.Database.Host, @@ -60,14 +56,17 @@ func connectDb(config Config) { db, err := sql.Open("postgres", connectionString) if err != nil { - log.Fatalf("Error connecting to the database: %v", err) + return nil, fmt.Errorf("error connecting to the database: %v", err) } - defer db.Close() err = db.Ping() if err != nil { - log.Fatalf("Error pinging the database: %v", err) + db.Close() + return nil, fmt.Errorf("error pinging the database: %v", err) } + + log.Println("Successfully connected to the database.") + return db, nil } func init() { @@ -84,9 +83,11 @@ func init() { return } - connectDb(config) - - + db, err = connectDb(config) + if err != nil { + log.Println("Error initializing db connection:", err) + return + } } func main() { @@ -103,7 +104,7 @@ func main() { if err != nil { log.Println("Error opening connection:", err) return - } + } stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) @@ -111,5 +112,9 @@ func main() { log.Println("Gracefully shutting down.") client.Close() + + if db != nil { + db.Close() + } }