package shell

import (
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
)

// Sentinel errors for control flow (break, continue, return).
// Execute and executeChain pass them to their caller unchanged instead of
// recording them in $?.
type breakErr struct{ n int }
type continueErr struct{ n int }
type returnErr struct{ code int }

func (e breakErr) Error() string { return fmt.Sprintf("break %d", e.n) }

func (e continueErr) Error() string { return fmt.Sprintf("continue %d", e.n) }

func (e returnErr) Error() string { return fmt.Sprintf("return %d", e.code) }

// exitCodeErr carries a non-zero exit code that sets $? without a message.
// Used by functions, test, false, etc.
type exitCodeErr struct{ code int }

func (e exitCodeErr) Error() string { return "" }

// Shell is the interpreter state shared by every command run through it.
type Shell struct {
	vars     map[string]string               // shell variables, including specials such as "?" and "#"
	builtins map[string]func([]string) error // builtin name → implementation; not created in New (presumably by initBuiltins — TODO confirm)
	funcs    map[string]string               // function name → body
	lastExit int                             // numeric copy of $?, kept in sync by setExitCode

	Stdin  io.Reader
	Stdout io.Writer
	Stderr io.Writer

	args []string // current positional parameters ($1, $2, ...)
}

// New returns a Shell wired to the process's standard streams.
// It registers the builtins and presets $SHELL, $BASH_VERSION, $?, $#, $@,
// $* and $!. $PWD and $HOME are taken from the operating system when they
// can be determined.
func New() *Shell {
	s := &Shell{
		vars:   map[string]string{},
		funcs:  map[string]string{},
		Stdin:  os.Stdin,
		Stdout: os.Stdout,
		Stderr: os.Stderr,
	}
	s.initBuiltins()
	s.vars["SHELL"] = "bash-for-windows"
	s.vars["BASH_VERSION"] = "5.2.15(1)-release"
	s.vars["?"] = "0"
	s.vars["#"] = "0"
	s.vars["@"] = ""
	s.vars["*"] = ""
	s.vars["!"] = ""
	if pwd, err := os.Getwd(); err == nil {
		s.vars["PWD"] = pwd
	}
	if home, err := os.UserHomeDir(); err == nil {
		s.vars["HOME"] = home
	}
	return s
}

// SetArgs replaces the positional parameters ($1, $2, ...) with args.
// It also updates $#, $@ and $* (the latter two space-joined).
// Positional parameters left over from a previous, longer argument list are
// removed, so they read as unset rather than keeping stale values.
func (s *Shell) SetArgs(args []string) {
	// Deleting map entries while ranging over the map is permitted in Go.
	for name := range s.vars {
		if isPositionalParamName(name) {
			delete(s.vars, name)
		}
	}
	s.args = args
	joined := strings.Join(args, " ")
	s.vars["#"] = strconv.Itoa(len(args))
	s.vars["@"] = joined
	s.vars["*"] = joined
	for i, a := range args {
		s.vars[strconv.Itoa(i+1)] = a
	}
}

// isPositionalParamName reports whether name has the form of a positional
// parameter ("1", "2", ..., "10", ...). "0" and names with a leading zero are
// not positional parameters.
func isPositionalParamName(name string) bool {
	if name == "" || name[0] == '0' {
		return false
	}
	for i := 0; i < len(name); i++ {
		if name[i] < '0' || name[i] > '9' {
			return false
		}
	}
	return true
}

// GetVar returns the shell variable name.
// When the shell has no such variable, it falls back to the process
// environment (os.Getenv).
func (s *Shell) GetVar(name string) string {
	if v, ok := s.vars[name]; ok {
		return v
	}
	return os.Getenv(name)
}

// SetVar sets the shell variable name to value.
// It does not touch the process environment.
func (s *Shell) SetVar(name, value string) {
	s.vars[name] = value
}

// Execute runs commands from the given input string.
// Backslash-newline continuations are joined first. The input is then
// grouped into blocks by parseBlocks, and each block is executed in order.
// Ordinary errors only update $?; break/continue/return sentinels are
// returned to the caller.
func (s *Shell) Execute(input string) error { input = strings.ReplaceAll(input, "\\\n", " ") blocks := parseBlocks(input) for _, block := range blocks { if err := s.executeBlock(block); err != nil { switch err.(type) { case breakErr, continueErr, returnErr: return err } s.setExitCode(err) } } return nil } // IsIncomplete returns true if the input is an incomplete multi-line construct. func IsIncomplete(input string) bool { stmts := splitStatements(input) depth := 0 inSingle := false inDouble := false for _, ch := range input { switch ch { case '\'': if !inDouble { inSingle = !inSingle } case '"': if !inSingle { inDouble = !inDouble } } } if inSingle || inDouble { return true } for _, stmt := range stmts { w := firstWord(stmt) switch w { case "if", "for", "while", "until": depth++ case "fi", "done", "esac": depth-- case "{": depth++ case "}": if depth > 0 { depth-- } } if isFuncDefStart(stmt) && strings.HasSuffix(strings.TrimSpace(stmt), "{") { depth++ } } return depth > 0 } // parseBlocks groups statements into logical execution units. // Multi-line if/for/while/function blocks are gathered into single entries. 
func parseBlocks(input string) []string { stmts := splitStatements(input) var blocks []string var current []string kwDepth := 0 // if/for/while/until → fi/done nesting inFunc := false funcKwDepth := 0 // keyword nesting inside a function body for _, stmt := range stmts { stmt = strings.TrimSpace(stmt) if stmt == "" || strings.HasPrefix(stmt, "#") { continue } w := firstWord(stmt) if !inFunc { // Detect function definition opening with `{` if isFuncDefStart(stmt) && strings.Contains(stmt, "{") { braceIdx := strings.Index(stmt, "{") // Count keywords after the { on this same line funcKwDepth = 0 for _, p := range splitStatements(stmt[braceIdx+1:]) { switch firstWord(p) { case "if", "for", "while", "until": funcKwDepth++ case "fi", "done", "esac": funcKwDepth-- } } // If the line also ends with } it's a self-contained function if strings.HasSuffix(stmt, "}") { current = append(current, stmt) blocks = append(blocks, strings.Join(current, "\n")) current = nil funcKwDepth = 0 continue } inFunc = true current = append(current, stmt) continue } switch w { case "if", "for", "while", "until": kwDepth++ } kwDepth += embeddedKwDepth(stmt) current = append(current, stmt) switch w { case "fi", "done", "esac": kwDepth-- case "}": if kwDepth > 0 { kwDepth-- } } if kwDepth <= 0 && len(current) > 0 { kwDepth = 0 blocks = append(blocks, strings.Join(current, "\n")) current = nil } } else { // Inside function body — watch for } at funcKwDepth==0 if w == "}" && funcKwDepth <= 0 { current = append(current, stmt) blocks = append(blocks, strings.Join(current, "\n")) current = nil inFunc = false funcKwDepth = 0 continue } switch w { case "if", "for", "while", "until": funcKwDepth++ case "fi", "done", "esac": funcKwDepth-- } funcKwDepth += embeddedKwDepth(stmt) current = append(current, stmt) } } if len(current) > 0 { blocks = append(blocks, strings.Join(current, "\n")) } return blocks } // embeddedKwDepth returns the net depth change from keywords that appear // after do/then/else/elif 
within a single statement (excluding the first word, // which is handled separately by the caller). func embeddedKwDepth(stmt string) int { words := strings.Fields(stmt) delta := 0 for j := 1; j < len(words); j++ { switch words[j-1] { case "do", "then", "else", "elif": switch words[j] { case "if", "for", "while", "until": delta++ case "fi", "done", "esac": delta-- } } } return delta } // splitStatements splits input on semicolons and newlines, respecting quotes. func splitStatements(input string) []string { var result []string current := strings.Builder{} inSingle := false inDouble := false for i := 0; i < len(input); i++ { c := input[i] switch { case c == '\'' && !inDouble: inSingle = !inSingle current.WriteByte(c) case c == '"' && !inSingle: inDouble = !inDouble current.WriteByte(c) case (c == ';' || c == '\n') && !inSingle && !inDouble: if s := strings.TrimSpace(current.String()); s != "" { result = append(result, s) } current.Reset() default: current.WriteByte(c) } } if s := strings.TrimSpace(current.String()); s != "" { result = append(result, s) } return result } func firstWord(s string) string { fields := strings.Fields(s) if len(fields) == 0 { return "" } return fields[0] } func afterWord(s string) string { for i, ch := range s { if ch == ' ' || ch == '\t' { return strings.TrimSpace(s[i:]) } } return "" } func isFuncDefStart(stmt string) bool { if strings.HasPrefix(stmt, "function ") { return true } for i, ch := range stmt { if ch == ' ' || ch == '\t' { break } if ch == '(' { name := strings.TrimSpace(stmt[:i]) return isValidIdentifier(name) } } return false } func isValidIdentifier(s string) bool { if len(s) == 0 { return false } for i, c := range s { if i == 0 { if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { return false } } else { if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { return false } } } return true } func (s *Shell) executeBlock(block string) error { block = 
strings.TrimSpace(block) if block == "" || strings.HasPrefix(block, "#") { return nil } w := firstWord(block) switch w { case "if": return s.executeIf(block) case "for": return s.executeFor(block) case "while": return s.executeWhileUntil(block, false) case "until": return s.executeWhileUntil(block, true) } if isFuncDefStart(block) { return s.defineFunction(block) } for _, line := range strings.Split(block, "\n") { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } if err := s.executeLine(line); err != nil { return err } } return nil } func (s *Shell) executeLine(line string) error { return s.executeChain(line) } func (s *Shell) executeChain(line string) error { for _, seg := range splitBySemicolon(line) { seg = strings.TrimSpace(seg) if seg == "" { continue } if err := s.executeAndOrList(seg); err != nil { switch err.(type) { case breakErr, continueErr, returnErr: return err } s.setExitCode(err) } } return nil } func splitBySemicolon(line string) []string { var parts []string current := strings.Builder{} inSingle := false inDouble := false for i := 0; i < len(line); i++ { c := line[i] switch { case c == '\'' && !inDouble: inSingle = !inSingle current.WriteByte(c) case c == '"' && !inSingle: inDouble = !inDouble current.WriteByte(c) case c == ';' && !inSingle && !inDouble: parts = append(parts, current.String()) current.Reset() default: current.WriteByte(c) } } if current.Len() > 0 { parts = append(parts, current.String()) } return parts } func (s *Shell) executeAndOrList(line string) error { type tok struct { text string op string } var tokens []tok current := strings.Builder{} op := "" inSingle := false inDouble := false for i := 0; i < len(line); i++ { c := line[i] switch { case c == '\'' && !inDouble: inSingle = !inSingle current.WriteByte(c) case c == '"' && !inSingle: inDouble = !inDouble current.WriteByte(c) case c == '&' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '&': tokens = append(tokens, 
tok{current.String(), op}) current.Reset() op = "&&" i++ case c == '|' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '|': tokens = append(tokens, tok{current.String(), op}) current.Reset() op = "||" i++ default: current.WriteByte(c) } } if current.Len() > 0 { tokens = append(tokens, tok{current.String(), op}) } var lastErr error for i, t := range tokens { cmd := strings.TrimSpace(t.text) if cmd == "" { continue } run := i == 0 if i > 0 { run = (t.op == "&&" && lastErr == nil) || (t.op == "||" && lastErr != nil) } if run { err := s.executePipeline(cmd) lastErr = err s.setExitCode(err) } } return lastErr } func (s *Shell) setExitCode(err error) { if err == nil { s.vars["?"] = "0" s.lastExit = 0 } else if ec, ok := err.(exitCodeErr); ok { s.vars["?"] = fmt.Sprintf("%d", ec.code) s.lastExit = ec.code } else { s.vars["?"] = "1" s.lastExit = 1 } } // BuiltinNames returns a sorted list of all registered builtin names (for tab completion). func (s *Shell) BuiltinNames() []string { names := make([]string, 0, len(s.builtins)+len(s.funcs)) for k := range s.builtins { names = append(names, k) } for k := range s.funcs { names = append(names, k) } return names } // withIO temporarily swaps stdin/stdout/stderr, runs fn, then restores. // Pass nil to leave the corresponding stream unchanged. func (s *Shell) withIO(stdin io.Reader, stdout io.Writer, stderr io.Writer, fn func() error) error { oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr if stdin != nil { s.Stdin = stdin } if stdout != nil { s.Stdout = stdout } if stderr != nil { s.Stderr = stderr } err := fn() s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr return err }