package api

import (
	"encoding/json"
	"fmt"
	"strings"

	lua "github.com/yuin/gopher-lua"
)

// TestResult holds the outcome of a single test assertion.
type TestResult struct {
	Pass    bool
	Message string
}

// RunLuaTestScript executes a Lua script against a Response and returns the
// assertions the script recorded, in the order they were recorded.
//
// Globals available in the script:
//
//	status       (number)  HTTP status code, e.g. 200
//	status_text  (string)  Full status string, e.g. "200 OK"
//	body         (string)  Raw response body
//	headers      (table)   Response headers, lower-cased keys, first value only
//	json_body    (table)   Parsed JSON body, or nil if not JSON / not parseable
//	duration_ms  (number)  Round-trip duration in milliseconds
//	size         (number)  Response body size in bytes
//
// Helper functions:
//
//	pass(message)          Record a passing assertion
//	fail(message)          Record a failing assertion (does NOT abort the script)
//	assert_eq(a, b, msg)   pass(msg) if a == b, else fail(msg). Strings and
//	                       numbers are additionally compared by their string
//	                       form, so assert_eq(200, "200") passes.
//	assert_status(code)    Like assert_eq(status, code, "status == "..code)
//	json_decode(str)       Parse a JSON string → Lua value; on malformed input
//	                       returns nil plus an error message
//
// A blank script yields (nil, nil). A nil resp yields an error. A Lua
// compile or runtime error is NOT returned as a Go error: it is appended to
// the results as a failing TestResult prefixed with "Lua error: ", so that
// assertions recorded before the failure are still reported.
func RunLuaTestScript(script string, resp *Response) ([]TestResult, error) {
	if strings.TrimSpace(script) == "" {
		return nil, nil
	}
	if resp == nil {
		// Every global below dereferences resp; fail cleanly instead of panicking.
		return nil, fmt.Errorf("lua test script: response is nil")
	}

	// NOTE(review): SkipOpenLibs is false, so the interpreter's standard
	// libraries are opened for the script. Presumably scripts are authored by
	// the local user and therefore trusted; confirm before running scripts
	// from other sources.
	L := lua.NewState(lua.Options{SkipOpenLibs: false})
	defer L.Close()

	var results []TestResult
	record := func(pass bool, msg string) {
		results = append(results, TestResult{Pass: pass, Message: msg})
	}

	// pass / fail helpers.
	L.SetGlobal("pass", L.NewFunction(func(L *lua.LState) int {
		record(true, L.OptString(1, ""))
		return 0
	}))
	L.SetGlobal("fail", L.NewFunction(func(L *lua.LState) int {
		record(false, L.OptString(1, ""))
		return 0
	}))

	L.SetGlobal("assert_eq", L.NewFunction(func(L *lua.LState) int {
		a, b := L.Get(1), L.Get(2)
		msg := L.OptString(3, fmt.Sprintf("%v == %v", a, b))

		// Lenient comparison kept for existing scripts: strings and numbers
		// are equal when their string forms match (200 vs "200").
		coercedEqual := lua.LVCanConvToString(a) && lua.LVCanConvToString(b) &&
			lua.LVAsString(a) == lua.LVAsString(b)

		// The coerced check alone rejects every boolean, nil and table
		// operand, so fall back to Lua's own equality for those.
		if coercedEqual || L.Equal(a, b) {
			record(true, msg)
		} else {
			record(false, fmt.Sprintf("%s (got %v, expected %v)", msg, a, b))
		}
		return 0
	}))

	L.SetGlobal("assert_status", L.NewFunction(func(L *lua.LState) int {
		expected := L.CheckInt(1)
		msg := fmt.Sprintf("status == %d", expected)
		if resp.StatusCode == expected {
			record(true, msg)
		} else {
			record(false, fmt.Sprintf("%s (got %d)", msg, resp.StatusCode))
		}
		return 0
	}))

	// json_decode helper: returns the decoded value, or nil plus the error text.
	L.SetGlobal("json_decode", L.NewFunction(func(L *lua.LState) int {
		s := L.CheckString(1)
		var v interface{}
		if err := json.Unmarshal([]byte(s), &v); err != nil {
			L.Push(lua.LNil)
			L.Push(lua.LString(err.Error()))
			return 2
		}
		L.Push(goToLua(L, v))
		return 1
	}))

	// Response globals.
	L.SetGlobal("status", lua.LNumber(resp.StatusCode))
	L.SetGlobal("status_text", lua.LString(resp.Status))
	L.SetGlobal("body", lua.LString(resp.Body))
	L.SetGlobal("duration_ms", lua.LNumber(resp.Duration.Milliseconds()))
	L.SetGlobal("size", lua.LNumber(resp.Size))

	// headers table: lower-cased keys, only the first value of each header.
	headersTbl := L.NewTable()
	for k, vals := range resp.Headers {
		if len(vals) > 0 {
			L.SetField(headersTbl, strings.ToLower(k), lua.LString(vals[0]))
		}
	}
	L.SetGlobal("headers", headersTbl)

	// json_body: the parsed body when the response is flagged as JSON and
	// actually parses; nil in every other case.
	var jsonBody lua.LValue = lua.LNil
	if resp.IsJSON {
		var parsed interface{}
		if err := json.Unmarshal([]byte(resp.Body), &parsed); err == nil {
			jsonBody = goToLua(L, parsed)
		}
	}
	L.SetGlobal("json_body", jsonBody)

	if err := L.DoString(script); err != nil {
		// Keep the results collected before the error and report the error
		// itself as a failing result, so the UI can display everything.
		record(false, "Lua error: "+err.Error())
	}
	return results, nil
}

// goToLua converts a Go value (from json.Unmarshal) into a lua.LValue.
func goToLua(L *lua.LState, v interface{}) lua.LValue { if v == nil { return lua.LNil } switch val := v.(type) { case bool: return lua.LBool(val) case float64: return lua.LNumber(val) case string: return lua.LString(val) case []interface{}: tbl := L.NewTable() for i, item := range val { tbl.RawSetInt(i+1, goToLua(L, item)) } return tbl case map[string]interface{}: tbl := L.NewTable() for k, item := range val { L.SetField(tbl, k, goToLua(L, item)) } return tbl default: return lua.LString(fmt.Sprintf("%v", val)) } }