侧边栏壁纸
博主头像
分享你我博主等级

行动起来,活在当下

  • 累计撰写 112 篇文章
  • 累计创建 13 个标签
  • 累计收到 0 条评论

目 录CONTENT

文章目录

golang使用koanf读取yaml文件并监听文件更新

管理员
2024-12-24 / 0 评论 / 0 点赞 / 12 阅读 / 10736 字

监听文件更新时更新指定值

支持通过结构体指定标签过滤更新

var mu sync.RWMutex

// Cfg 是全局配置
var Cfg atomic.Value // 使用 atomic.Value 来存储配置

func InitConfig() {
	newCfg := &models.Config{}
	// 创建 Koanf 实例
	k := koanf.New(".")
	proxy := file.Provider("conf/proxy.yaml")
	// 加载第一个配置文件
	if err := k.Load(proxy, yaml.Parser()); err != nil {
		log.Fatalf("Error loading config file 'proxy.yaml': %v", err)
	}

	backup := file.Provider("conf/backup_upstream.yaml")
	// 加载第二个配置文件并合并
	if err := k.Load(backup, yaml.Parser()); err != nil {
		log.Fatalf("Error loading config file 'backup_upstream.yaml': %v", err)
	}
	proxy.Watch(func(event interface{}, err error) {
		if err != nil {
			log.Printf("watch error: %v", err)
			return
		}
		reloadChangedConfig(proxy, true)
	})

	backup.Watch(func(event interface{}, err error) {
		if err != nil {
			log.Printf("watch error: %v", err)
			return
		}
		reloadChangedConfig(backup, false)
	})

	// 将配置解码到结构体中
	if err := k.Unmarshal("", newCfg); err != nil {
		log.Fatalf("Unable to unmarshal into struct: %v", err)
	}
	// 使用 atomic.Value 存储新的配置
	Cfg.Store(newCfg)
	config := GetConfig()
	setDefaults(config, k, "", true)
	fmt.Println("配置文件加载成功")
}

// GetConfig 获取配置
func GetConfig() *models.Config {
	cfg := Cfg.Load()
	if cfg == nil {
		return nil
	}
	return cfg.(*models.Config)
}

// 只更新发生变化的字段
func reloadChangedConfig(p koanf.Provider, isNil bool) {
	// 比较并只更新变化的部分
	mu.Lock()
	defer mu.Unlock()
	k := koanf.New(".")
	if err := k.Load(p, yaml.Parser()); err != nil {
		log.Printf("Error reloading config: %v", err)
		return
	}

	// 创建一个临时配置来解码新的内容
	newCfg := &models.Config{}
	if err := k.Unmarshal("", newCfg); err != nil {
		log.Printf("Error unmarshalling new config: %v", err)
		return
	}

	// 设置默认值(如果需要)
	setDefaults(newCfg, k, "", isNil)
	// 动态比较原始配置和新配置的字段
	updateConfigIfChanged(GetConfig(), newCfg)
}

// 动态比较并更新结构体的变化字段
func updateConfigIfChanged(target, new interface{}) {
	// 使用反射遍历结构体的所有字段
	targetVal := reflect.ValueOf(target).Elem()
	newVal := reflect.ValueOf(new).Elem()

	// 获取原始结构体的类型
	typ := targetVal.Type()

	for i := 0; i < targetVal.NumField(); i++ {
		if !targetVal.IsValid() {
			continue
		}
		targetField := targetVal.Field(i)
		if !newVal.IsValid() {
			continue
		}
		newField := newVal.Field(i)
		noUpdateTag := typ.Field(i).Tag.Get("no_update") // 获取 no_update 标签
		// 确保字段是导出的
		if noUpdateTag != "" || targetVal.Type().Field(i).PkgPath != "" {
			continue // 跳过未导出的字段
		}

		// 如果字段是指针类型,解引用
		if targetField.Kind() == reflect.Ptr {
			// 解引用指针并进行更新
			if targetField.IsNil() {
				continue
			}
			// 递归调用更新嵌套结构体
			updateConfigIfChanged(targetField.Interface(), newField.Interface())
		} else if targetField.Kind() == reflect.Map {
			// 如果字段是 map 类型,遍历 map 中的每个键值对
			updateMapIfChanged(targetField, newField)
		} else if targetField.Kind() == reflect.Struct {
			// 如果字段是嵌套的结构体,递归调用处理
			updateConfigIfChanged(targetField.Addr().Interface(), newField.Addr().Interface())
		} else if targetField.Kind() == reflect.Slice {
			if targetField.IsNil() {
				continue
			}
			if targetField.Len() == 0 {
				targetField.Set(newField)
				continue
			}
			// 如果字段是切片类型,遍历切片中的每个元素
			for j := 0; j < targetField.Len(); j++ {
				if j >= newField.Len() {
					continue
				}
				if j >= newField.Len() {
					continue
				}
				targetElem := targetField.Index(j)
				newElem := newField.Index(j)
				// 确保新元素类型与旧元素类型一致
				if targetElem.Type() != newElem.Type() {
					continue
				}
				// 如果元素是结构体,递归调用处理
				if targetElem.Kind() == reflect.Struct {
					updateConfigIfChanged(targetElem.Addr().Interface(), newElem.Addr().Interface())
				} else if targetElem.Kind() == reflect.Slice {
					updateConfigIfChanged(targetElem.Addr().Interface(), newElem.Addr().Interface())
				} else if targetElem.Kind() == reflect.Map {
					updateConfigIfChanged(targetElem, newElem)
				} else if targetElem.Kind() == reflect.Ptr {
					updateConfigIfChanged(targetElem.Interface(), newElem.Interface())
				} else {
					targetField.Set(newField)
				}
			}
		} else {
			// 如果类型不匹配,直接返回
			if targetField.Type() != newField.Type() {
				continue
			}
			// 打印字段名、旧值和新值
			//fieldName := targetVal.Type().Field(i).Name
			oldValue := targetField.Interface()
			newValue := newField.Interface()
			if !reflect.DeepEqual(oldValue, newValue) {
				// 如果字段值发生变化,更新 target 的相应字段
				targetField.Set(newField)
			}
		}

	}
}

// 比较并更新 map 类型的字段
func updateMapIfChanged(targetField, newField reflect.Value) {
	// 检查两个字段是否都是 map 类型
	if targetField.Kind() != reflect.Map || newField.Kind() != reflect.Map {
		log.Fatalf("Both fields must be maps, got %v and %v", targetField.Kind(), newField.Kind())
		return
	}
	// 获取 map 中的键值类型
	targetMapType := targetField.Type().Elem()
	newMapType := newField.Type().Elem()

	// 如果类型不匹配,直接返回
	if targetMapType != newMapType {
		log.Fatalf("Map value types do not match: %v != %v", targetMapType, newMapType)
		return
	}

	// 处理 map 中的值
	switch targetMapType.Kind() {
	case reflect.Ptr: // 如果是指针类型的 map 元素
		// 获取 map 中的每个键值对
		for _, key := range targetField.MapKeys() {
			// Ensure the key exists in the new map
			newValue := newField.MapIndex(key)
			if !newValue.IsValid() {
				continue
			}

			targetValue := targetField.MapIndex(key)
			if !targetValue.IsValid() {
				continue
			}
			updateConfigIfChanged(targetValue.Interface(), newValue.Interface())
		}
	case reflect.Struct: // 如果是结构体类型的 map 元素
		// 获取 map 中的每个键值对
		for _, key := range targetField.MapKeys() {
			// Ensure the key exists in the new map
			newValue := newField.MapIndex(key)
			if !newValue.IsValid() {
				continue
			}

			targetValue := targetField.MapIndex(key)
			if !targetValue.IsValid() {
				continue
			}
			updateConfigIfChanged(targetValue.Interface(), newValue.Interface())
		}
	default:
		log.Printf("Unsupported map value type: %v", targetMapType.Kind())
	}
}

func setDefaults(v interface{}, k *koanf.Koanf, parentPath string, isNil bool) {
	if v == nil {
		return
	}
	val := reflect.ValueOf(v).Elem() // 获取指向结构体的指针
	if !val.IsValid() {
		// 处理无效值的情况
		return
	}
	typ := val.Type()

	// 遍历结构体的每个字段
	for i := 0; i < val.NumField(); i++ {
		field := val.Field(i)
		fieldType := typ.Field(i)

		// 获取字段标签中的默认值
		defaultValue := fieldType.Tag.Get("default")
		yamlPath := fieldType.Tag.Get("yaml")

		// 如果字段不可设置 (unexported field), 跳过它
		if !field.CanSet() || yamlPath == "" {
			continue
		}
		// Construct the full key path (e.g., "log.pull.enable")
		fullPath := ""
		if parentPath != "" {
			fullPath = parentPath + "." + yamlPath
		} else {
			fullPath = yamlPath
		}
		// 如果是嵌套结构体
		if field.Kind() == reflect.Struct {
			// 递归调用 setDefaults 以设置嵌套结构体的默认值
			setDefaults(field.Addr().Interface(), k, fullPath, isNil)
		} else if field.Kind() == reflect.Slice {
			// 如果字段是 nil,初始化为一个空的切片
			if field.IsNil() && isNil {
				newSlice := reflect.MakeSlice(field.Type(), 0, 0)
				field.Set(newSlice)
			}

			// 获取切片的当前长度
			sliceLen := field.Len()

			// 如果切片为空并且存在默认值,插入默认值
			defaultValue := fieldType.Tag.Get("default")
			if sliceLen == 0 && defaultValue != "" {
				// 为切片元素类型实例化一个元素
				sliceElem := reflect.New(field.Type().Elem()).Elem()

				// 根据字段的默认值类型设置元素值
				if sliceElem.Kind() == reflect.String {
					sliceElem.SetString(defaultValue)
				} else if sliceElem.Kind() == reflect.Int {
					defaultValueInt, err := strconv.Atoi(defaultValue)
					if err != nil {
						log.Fatalf("Failed to parse default value for %s: %v", yamlPath, err)
					}
					sliceElem.SetInt(int64(defaultValueInt))
				} else if sliceElem.Kind() == reflect.Bool {
					defaultValueBool, err := strconv.ParseBool(defaultValue)
					if err != nil {
						log.Fatalf("Failed to parse default value for %s: %v", yamlPath, err)
					}
					sliceElem.SetBool(defaultValueBool)
				}

				// 将新元素追加到切片中
				field.Set(reflect.Append(field, sliceElem))
			}

			// 现在可以安全地访问切片元素
			for i := 0; i < field.Len(); i++ {
				sliceElem := field.Index(i)

				// 如果元素是结构体,递归设置它的默认值
				if sliceElem.Kind() == reflect.Struct {
					setDefaults(sliceElem.Addr().Interface(), k, fullPath, isNil)
				} else if sliceElem.Kind() == reflect.Ptr {
					// 如果切片元素是指针类型并且为 nil,实例化为新的指针
					if sliceElem.IsNil() && isNil {
						newPtr := reflect.New(sliceElem.Type().Elem())
						sliceElem.Set(newPtr)

						// 如果指针指向结构体或 map,递归设置默认值
						if newPtr.Elem().Kind() == reflect.Struct {
							setDefaults(newPtr.Interface(), k, fullPath, isNil)
						} else if newPtr.Elem().Kind() == reflect.Map {
							if newPtr.Elem().IsNil() {
								newPtr.Elem().Set(reflect.MakeMap(newPtr.Elem().Type()))
							}
						}
					} else {
						// 如果指针不为 nil,递归设置它指向的值
						setDefaults(sliceElem.Interface(), k, fullPath, isNil)
					}
				}
			}
		} else if field.Kind() == reflect.Ptr {
			// 如果是指针类型,检查是否为 nil 并且是结构体类型
			if field.IsNil() && isNil {
				// 如果字段为空,实例化为默认的零值
				newPtr := reflect.New(field.Type().Elem())
				field.Set(newPtr)

				// Check if the newly created pointer is a struct or map, and instantiate if needed
				if newPtr.Elem().Kind() == reflect.Struct {
					setDefaults(newPtr.Interface(), k, fullPath, isNil) // Recursively set defaults for the nested struct
				} else if newPtr.Elem().Kind() == reflect.Map {
					if newPtr.Elem().IsNil() {
						newPtr.Elem().Set(reflect.MakeMap(newPtr.Elem().Type()))
					}
				}
			} else {
				// If pointer is not nil, recursively set defaults for the value it points to
				setDefaults(field.Interface(), k, fullPath, isNil)
			}
		} else if field.Kind() == reflect.Map {
			// 如果是 map 类型 (例如 Upstreams),检查是否为 nil 并实例化
			if field.IsNil() && isNil {
				newMap := reflect.MakeMap(field.Type())
				// 实例化为一个空的 map
				field.Set(newMap)
			}
			// 遍历 map 的值
			for _, key := range field.MapKeys() {
				mapValue := field.MapIndex(key)
				if mapValue.Kind() == reflect.Ptr && !mapValue.IsNil() {
					// 对 map 中的每个指针值递归调用 setDefaults
					setDefaults(mapValue.Interface(), k, fullPath, isNil)
				}
			}
		} else if defaultValue != "" && field.IsZero() && !k.Exists(fullPath) {
			// 仅当字段是零值时,才设置默认值
			if field.Kind() == reflect.String {
				field.SetString(defaultValue)
			} else if field.Kind() == reflect.Int {
				defaultValueInt, err := strconv.Atoi(defaultValue)
				if err != nil {
					log.Fatalf("Failed to parse default value for %s: %v", yamlPath, err)
				}
				field.SetInt(int64(defaultValueInt))
			} else if field.Kind() == reflect.Bool {
				defaultValueBool, err := strconv.ParseBool(defaultValue)
				if err != nil {
					log.Fatalf("Failed to parse default value for %s: %v", yamlPath, err)
				}
				field.SetBool(defaultValueBool)
			}
		}
	}
}

0

评论区