diff --git a/cmd/bash/main.go b/cmd/bash/main.go index b01505b..0e244a7 100644 --- a/cmd/bash/main.go +++ b/cmd/bash/main.go @@ -1,66 +1,261 @@ package bash import ( - "bufio" - "fmt" - "os" - "strings" - "github.com/cametendo/bash-for-windows/internal/shell" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/chzyer/readline" + + "github.com/cametendo/bash-for-windows/internal/shell" ) func Run() error { - args := os.Args[1:] - - if len(args) > 0 { - if args[0] == "-c" && len(args) > 1 { - // Execute a command string - return runCommand(strings.Join(args[1:], " ")) - } - // Run a script file - return runFile(args[0]) - } - - return interactive() + args := os.Args[1:] + + if len(args) > 0 { + switch args[0] { + case "-c": + if len(args) < 2 { + return fmt.Errorf("-c: option requires an argument") + } + sh := shell.New() + if len(args) > 2 { + sh.SetArgs(args[2:]) + } + return sh.Execute(strings.Join(args[1:], " ")) + case "--version": + fmt.Println("bash-for-windows 2.0.0 (Go-based)") + fmt.Println("Provides bash-compatible shell for Windows.") + return nil + default: + return runFile(args[0], args[1:]) + } + } + + return interactive() } -func runCommand(cmd string) error { - sh := shell.New() - return sh.Execute(cmd) +func runFile(path string, args []string) error { + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("%s: %v", path, err) + } + sh := shell.New() + sh.SetArgs(append([]string{path}, args...)) + sh.SetVar("0", path) + return sh.Execute(string(data)) } -func runFile(path string) error { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("%s: %v", path, err) - } - sh := shell.New() - return sh.Execute(string(data)) +func historyFile() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + // On Windows this lands in %USERPROFILE%\.bash_history + return filepath.Join(home, ".bash_history") } func interactive() error { - sh := shell.New() - reader := bufio.NewReader(os.Stdin) - - fmt.Println("bash-for-windows v1.0.0") - - for { - fmt.Print("bash$ ") - input, err := reader.ReadString('\n') - if err != nil { - break - } - - input = strings.TrimSpace(input) - if input == "" { - continue - } - if input == "exit" { - break - } - - if err := sh.Execute(input); err != nil { - fmt.Fprintf(os.Stderr, "bash: %v\n", err) - } - } - return nil + sh := shell.New() + + // Build completer + completer := readline.NewPrefixCompleter( + readline.PcItemDynamic(func(line string) []string { + return dynamicComplete(sh, line) + }), + ) + + rl, err := readline.NewEx(&readline.Config{ + HistoryFile: historyFile(), + AutoComplete: completer, + InterruptPrompt: "^C", + EOFPrompt: "exit", + HistorySearchFold: true, + FuncFilterInputRune: filterInput, + }) + if err != nil { + // Fall back to dumb interactive mode + return interactiveDumb(sh) + } + defer rl.Close() + + fmt.Fprintln(os.Stdout, "bash-for-windows v2.0.0 (type 'exit' or Ctrl+D to quit)") + + var multiLine strings.Builder + + for { + prompt := buildPrompt(sh) + if multiLine.Len() > 0 { + prompt = "> " + } + rl.SetPrompt(prompt) + + line, err := rl.Readline() + if err != nil { + if err.Error() == "Interrupt" { + // Ctrl+C: abort current multi-line input + multiLine.Reset() + fmt.Fprintln(os.Stdout) + continue + } + break // EOF + } + + if multiLine.Len() > 0 { + multiLine.WriteString("\n") + } + multiLine.WriteString(line) + + input := multiLine.String() + if shell.IsIncomplete(input) { + continue // wait for more input + } + multiLine.Reset() + + input = strings.TrimSpace(input) + if input == "" { + continue + } + + if err := sh.Execute(input); err != nil { + fmt.Fprintf(os.Stderr, "bash: %v\n", err) + } + } + return nil +} + +// interactiveDumb is a fallback that doesn't need readline. +func interactiveDumb(sh *shell.Shell) error { + fmt.Fprintln(os.Stdout, "bash-for-windows v2.0.0") + + var multiLine strings.Builder + buf := make([]byte, 4096) + + for { + prompt := buildPrompt(sh) + if multiLine.Len() > 0 { + prompt = "> " + } + fmt.Fprint(os.Stdout, prompt) + + n, err := os.Stdin.Read(buf) + if n > 0 { + chunk := string(buf[:n]) + if multiLine.Len() > 0 { + multiLine.WriteString("\n") + } + multiLine.WriteString(strings.TrimRight(chunk, "\r\n")) + + input := multiLine.String() + if shell.IsIncomplete(input) { + continue + } + multiLine.Reset() + input = strings.TrimSpace(input) + if input == "" { + continue + } + if execErr := sh.Execute(input); execErr != nil { + fmt.Fprintf(os.Stderr, "bash: %v\n", execErr) + } + } + if err != nil { + break + } + } + return nil +} + +func buildPrompt(sh *shell.Shell) string { + pwd, _ := os.Getwd() + home, _ := os.UserHomeDir() + if home != "" && strings.HasPrefix(pwd, home) { + pwd = "~" + pwd[len(home):] + } + // Show exit code in prompt if non-zero + exitCode := sh.GetVar("?") + suffix := "$ " + if exitCode != "0" && exitCode != "" { + suffix = "[" + exitCode + "]$ " + } + return pwd + suffix +} + +// dynamicComplete provides tab completion for commands and paths. +func dynamicComplete(sh *shell.Shell, line string) []string { + line = strings.TrimLeft(line, " \t") + var completions []string + + // Check if we're completing the first word (command) or an argument (path) + words := strings.Fields(line) + completingCommand := len(words) == 0 || (len(words) == 1 && !strings.HasSuffix(line, " ")) + + if completingCommand { + prefix := "" + if len(words) == 1 { + prefix = words[0] + } + // Builtin/function names — access via the shell instance + // (we can't iterate unexported fields, so use a public method) + for _, name := range sh.BuiltinNames() { + if strings.HasPrefix(name, prefix) { + completions = append(completions, name) + } + } + // PATH executables + for _, dir := range filepath.SplitList(os.Getenv("PATH")) { + entries, err := os.ReadDir(dir) + if err != nil { + continue + } + for _, e := range entries { + name := e.Name() + // Strip .exe on completion display + name = strings.TrimSuffix(name, ".exe") + if strings.HasPrefix(name, prefix) { + completions = append(completions, name) + } + } + } + } else { + // Path completion + prefix := "" + if len(words) > 0 { + prefix = words[len(words)-1] + if strings.HasSuffix(line, " ") { + prefix = "" + } + } + dir := filepath.Dir(prefix) + base := filepath.Base(prefix) + if prefix == "" || strings.HasSuffix(prefix, "/") || strings.HasSuffix(prefix, "\\") { + dir = prefix + base = "" + } + entries, err := os.ReadDir(dir) + if err == nil { + for _, e := range entries { + name := e.Name() + if strings.HasPrefix(name, base) { + p := filepath.Join(dir, name) + if e.IsDir() { + p += "/" + } + completions = append(completions, p) + } + } + } + } + + return completions +} + +func filterInput(r rune) (rune, bool) { + // Block Ctrl+Z (26) — on Windows this would suspend; we handle it gracefully + if r == 26 { + return r, false + } + return r, true } diff --git a/go.mod b/go.mod index 4d79f81..c583dab 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/cametendo/bash-for-windows go 1.26.2 + +require ( + github.com/chzyer/readline v1.5.1 // indirect + golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2de3a80 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/shell/builtins.go b/internal/shell/builtins.go index 35b451c..68a0a8f 100644 --- a/internal/shell/builtins.go +++ b/internal/shell/builtins.go @@ -1,448 +1,1806 @@ package shell import ( - "fmt" - "os" - "path/filepath" - "strings" + "bufio" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "time" ) -func (s *Shell) builtinCd(args []string) error { - path := "" - if len(args) > 0 { - path = args[0] - } else { - path = os.Getenv("HOME") - if path == "" { - path = os.Getenv("USERPROFILE") - } - if path == "" { - path = "." - } - } - - // Handle ~ - if strings.HasPrefix(path, "~") { - home := os.Getenv("HOME") - if home == "" { - home = os.Getenv("USERPROFILE") - } - if len(path) > 1 { - path = home + path[1:] - } else { - path = home - } - } - - // Handle relative paths - if !filepath.IsAbs(path) { - pwd, _ := os.Getwd() - path = filepath.Join(pwd, path) - } - - if err := os.Chdir(path); err != nil { - return fmt.Errorf("cd: %s: %v", path, err) - } - - s.vars["PWD"], _ = os.Getwd() - return nil -} - -func (s *Shell) builtinPwd(args []string) error { - pwd, err := os.Getwd() - if err != nil { - return err - } - fmt.Println(pwd) - return nil -} - -func (s *Shell) builtinEcho(args []string) error { - noNewline := false - escape := false - var parts []string - - for _, arg := range args { - switch arg { - case "-n": - noNewline = true - case "-e": - escape = true - case "-en", "-ne": - noNewline = true - escape = true - default: - if escape { - arg = strings.ReplaceAll(arg, "\\n", "\n") - arg = strings.ReplaceAll(arg, "\\t", "\t") - arg = strings.ReplaceAll(arg, "\\\\", "\\") - } - parts = append(parts, arg) - } - } - - line := strings.Join(parts, " ") - if noNewline { - fmt.Print(line) - } else { - fmt.Println(line) - } - return nil -} - -func (s *Shell) builtinExit(args []string) error { - code := 0 - if len(args) > 0 { - fmt.Sscanf(args[0], "%d", &code) - } - os.Exit(code) - return nil -} - -func (s *Shell) builtinExport(args []string) error { - for _, arg := range args { - parts := strings.SplitN(arg, "=", 2) - if len(parts) == 2 { - s.vars[parts[0]] = parts[1] - os.Setenv(parts[0], parts[1]) - } else if len(parts) == 1 { - // export NAME (mark for export) - if val, ok := s.vars[parts[0]]; ok { - os.Setenv(parts[0], val) - } - } - } - return nil -} - -func (s *Shell) builtinSource(args []string) error { - if len(args) == 0 { - return fmt.Errorf("source: filename argument required") - } - - data, err := os.ReadFile(args[0]) - if err != nil { - return fmt.Errorf("source: %v", err) - } - - return s.Execute(string(data)) -} - var aliases = make(map[string]string) +func (s *Shell) initBuiltins() { + s.builtins = map[string]func([]string) error{ + // Shell builtins + "cd": s.builtinCd, + "pwd": s.builtinPwd, + "echo": s.builtinEcho, + "exit": s.builtinExit, + "export": s.builtinExport, + "source": s.builtinSource, + ".": s.builtinSource, + "alias": s.builtinAlias, + "unalias": s.builtinUnalias, + "type": s.builtinType, + "test": s.builtinTest, + "[": s.builtinTest, + "read": s.builtinRead, + "printf": s.builtinPrintf, + "true": s.builtinTrue, + "false": s.builtinFalse, + "set": s.builtinSet, + "unset": s.builtinUnset, + "env": s.builtinEnv, + "which": s.builtinWhich, + "return": s.builtinReturn, + "break": s.builtinBreak, + "continue": s.builtinContinue, + "shift": s.builtinShift, + "declare": s.builtinDeclare, + "local": s.builtinDeclare, + "command": s.builtinCommand, + "jobs": s.builtinJobs, + // Coreutils + "ls": s.cmdLs, + "cat": s.cmdCat, + "grep": s.cmdGrep, + "head": s.cmdHead, + "tail": s.cmdTail, + "sort": s.cmdSort, + "wc": s.cmdWc, + "find": s.cmdFind, + "cp": s.cmdCp, + "mv": s.cmdMv, + "rm": s.cmdRm, + "mkdir": s.cmdMkdir, + "touch": s.cmdTouch, + "clear": s.cmdClear, + "cut": s.cmdCut, + "tr": s.cmdTr, + "uniq": s.cmdUniq, + "tee": s.cmdTee, + "date": s.cmdDate, + "sleep": s.cmdSleep, + "basename": s.cmdBasename, + "dirname": s.cmdDirname, + "sed": s.cmdSed, + "xargs": s.cmdXargs, + } +} + +// ─── Shell builtins ─────────────────────────────────────────────────────────── + +func (s *Shell) builtinCd(args []string) error { + path := "" + if len(args) > 0 { + path = args[0] + } else { + path = s.GetVar("HOME") + if path == "" { + path = os.Getenv("USERPROFILE") + } + } + path = s.expandWord(path) + if !filepath.IsAbs(path) { + pwd, _ := os.Getwd() + path = filepath.Join(pwd, path) + } + if err := os.Chdir(path); err != nil { + return fmt.Errorf("cd: %s: %v", path, err) + } + s.vars["PWD"], _ = os.Getwd() + return nil +} + +func (s *Shell) builtinPwd(args []string) error { + pwd, err := os.Getwd() + if err != nil { + return err + } + fmt.Fprintln(s.Stdout, pwd) + return nil +} + +func (s *Shell) builtinEcho(args []string) error { + noNewline := false + escape := false + var parts []string + + for _, arg := range args { + switch arg { + case "-n": + noNewline = true + case "-e": + escape = true + case "-E": + escape = false + case "-ne", "-en": + noNewline = true + escape = true + default: + if escape { + arg = strings.ReplaceAll(arg, `\n`, "\n") + arg = strings.ReplaceAll(arg, `\t`, "\t") + arg = strings.ReplaceAll(arg, `\r`, "\r") + arg = strings.ReplaceAll(arg, `\\`, "\\") + arg = strings.ReplaceAll(arg, `\a`, "\a") + arg = strings.ReplaceAll(arg, `\b`, "\b") + } + parts = append(parts, arg) + } + } + line := strings.Join(parts, " ") + if noNewline { + fmt.Fprint(s.Stdout, line) + } else { + fmt.Fprintln(s.Stdout, line) + } + return nil +} + +func (s *Shell) builtinExit(args []string) error { + code := 0 + if len(args) > 0 { + fmt.Sscanf(args[0], "%d", &code) + } + os.Exit(code) + return nil +} + +func (s *Shell) builtinExport(args []string) error { + if len(args) == 0 { + for _, env := range os.Environ() { + fmt.Fprintln(s.Stdout, "export "+env) + } + return nil + } + for _, arg := range args { + parts := strings.SplitN(arg, "=", 2) + if len(parts) == 2 { + s.vars[parts[0]] = parts[1] + os.Setenv(parts[0], parts[1]) + } else if len(parts) == 1 { + if val, ok := s.vars[parts[0]]; ok { + os.Setenv(parts[0], val) + } + } + } + return nil +} + +func (s *Shell) builtinSource(args []string) error { + if len(args) == 0 { + return fmt.Errorf("source: filename argument required") + } + data, err := os.ReadFile(args[0]) + if err != nil { + return fmt.Errorf("source: %v", err) + } + // Pass remaining args as positional params to sourced script + if len(args) > 1 { + s.SetArgs(args[1:]) + } + return s.Execute(string(data)) +} + func (s *Shell) builtinAlias(args []string) error { - if len(args) == 0 { - for name, cmd := range aliases { - fmt.Printf("alias %s='%s'\n", name, cmd) - } - return nil - } - - for _, arg := range args { - parts := strings.SplitN(arg, "=", 2) - if len(parts) == 2 { - name := parts[0] - value := strings.Trim(parts[1], "'\"") - aliases[name] = value - } - } - return nil + if len(args) == 0 { + names := make([]string, 0, len(aliases)) + for name := range aliases { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + fmt.Fprintf(s.Stdout, "alias %s='%s'\n", name, aliases[name]) + } + return nil + } + for _, arg := range args { + parts := strings.SplitN(arg, "=", 2) + if len(parts) == 2 { + aliases[parts[0]] = strings.Trim(parts[1], "'\"") + } else { + if val, ok := aliases[parts[0]]; ok { + fmt.Fprintf(s.Stdout, "alias %s='%s'\n", parts[0], val) + } + } + } + return nil +} + +func (s *Shell) builtinUnalias(args []string) error { + for _, arg := range args { + if arg == "-a" { + aliases = make(map[string]string) + return nil + } + delete(aliases, arg) + } + return nil } func (s *Shell) builtinType(args []string) error { - for _, arg := range args { - if _, ok := s.builtins[arg]; ok { - fmt.Printf("%s is a shell builtin\n", arg) - } else if val, ok := aliases[arg]; ok { - fmt.Printf("%s is aliased to `%s`\n", arg, val) - } else if path := findExecutable(arg); path != "" { - fmt.Printf("%s is %s\n", arg, path) - } else { - fmt.Printf("%s: not found\n", arg) - } - } - return nil + for _, arg := range args { + if _, ok := s.builtins[arg]; ok { + fmt.Fprintf(s.Stdout, "%s is a shell builtin\n", arg) + } else if val, ok := aliases[arg]; ok { + fmt.Fprintf(s.Stdout, "%s is aliased to `%s`\n", arg, val) + } else if _, ok := s.funcs[arg]; ok { + fmt.Fprintf(s.Stdout, "%s is a function\n", arg) + } else if path := findExecutable(arg); path != "" { + fmt.Fprintf(s.Stdout, "%s is %s\n", arg, path) + } else { + fmt.Fprintf(s.Stdout, "%s: not found\n", arg) + } + } + return nil } -func (s *Shell) initBuiltins() { - s.builtins = map[string]func([]string) error{ - "cd": s.builtinCd, - "pwd": s.builtinPwd, - "echo": s.builtinEcho, - "exit": s.builtinExit, - "export": s.builtinExport, - "source": s.builtinSource, - "alias": s.builtinAlias, - "type": s.builtinType, - } +func (s *Shell) builtinTest(args []string) error { + // Strip trailing ] when invoked as [ + if len(args) > 0 && args[len(args)-1] == "]" { + args = args[:len(args)-1] + } + if s.evalTest(args) { + return nil + } + return exitCodeErr{1} } -func (s *Shell) initCommands() { - commands := map[string]func([]string) error{ - "ls": commandLs, - "cat": commandCat, - "grep": commandGrep, - "sort": commandSort, - "wc": commandWc, - "head": commandHead, - "find": commandFind, - "cp": commandCp, - "mv": commandMv, - "rm": commandRm, - "mkdir": commandMkdir, - "touch": commandTouch, - "clear": commandClear, - } - - for name, fn := range commands { - s.builtins[name] = fn - } +func (s *Shell) evalTest(args []string) bool { + if len(args) == 0 { + return false + } + // ! negation + if args[0] == "!" { + return !s.evalTest(args[1:]) + } + // Compound: -a and -o + for i, arg := range args { + if arg == "-a" { + return s.evalTest(args[:i]) && s.evalTest(args[i+1:]) + } + if arg == "-o" { + return s.evalTest(args[:i]) || s.evalTest(args[i+1:]) + } + } + // Unary + if len(args) == 2 { + val := args[1] + switch args[0] { + case "-f": + info, err := os.Stat(val) + return err == nil && info.Mode().IsRegular() + case "-d": + info, err := os.Stat(val) + return err == nil && info.IsDir() + case "-e": + _, err := os.Stat(val) + return err == nil + case "-L", "-h": + _, err := os.Lstat(val) + return err == nil + case "-r": + f, err := os.Open(val) + if err != nil { + return false + } + f.Close() + return true + case "-w": + f, err := os.OpenFile(val, os.O_WRONLY, 0) + if err != nil { + return false + } + f.Close() + return true + case "-x": + info, err := os.Stat(val) + return err == nil && (info.Mode()&0111 != 0 || runtime.GOOS == "windows") + case "-s": + info, err := os.Stat(val) + return err == nil && info.Size() > 0 + case "-z": + return val == "" + case "-n": + return val != "" + } + } + // Binary + if len(args) == 3 { + lhs, op, rhs := args[0], args[1], args[2] + switch op { + case "=", "==": + return lhs == rhs + case "!=": + return lhs != rhs + case "<": + return lhs < rhs + case ">": + return lhs > rhs + case "-eq": + return toInt(lhs) == toInt(rhs) + case "-ne": + return toInt(lhs) != toInt(rhs) + case "-lt": + return toInt(lhs) < toInt(rhs) + case "-le": + return toInt(lhs) <= toInt(rhs) + case "-gt": + return toInt(lhs) > toInt(rhs) + case "-ge": + return toInt(lhs) >= toInt(rhs) + case "-ef": + // Same file + i1, e1 := os.Stat(lhs) + i2, e2 := os.Stat(rhs) + return e1 == nil && e2 == nil && os.SameFile(i1, i2) + case "-nt": + i1, e1 := os.Stat(lhs) + i2, e2 := os.Stat(rhs) + return e1 == nil && e2 == nil && i1.ModTime().After(i2.ModTime()) + case "-ot": + i1, e1 := os.Stat(lhs) + i2, e2 := os.Stat(rhs) + return e1 == nil && e2 == nil && i1.ModTime().Before(i2.ModTime()) + } + } + // Single arg: true if non-empty + if len(args) == 1 { + return args[0] != "" + } + return false } -func commandLs(args []string) error { - path := "." - showAll := false - longFormat := false - - for _, arg := range args { - switch arg { - case "-la", "-al": - showAll = true - longFormat = true - case "-a": - showAll = true - case "-l": - longFormat = true - default: - if !strings.HasPrefix(arg, "-") { - path = arg - } - } - } - - entries, err := os.ReadDir(path) - if err != nil { - return fmt.Errorf("ls: %v", err) - } - - for _, entry := range entries { - name := entry.Name() - if !showAll && strings.HasPrefix(name, ".") { - continue - } - if longFormat { - info, _ := entry.Info() - if info != nil { - mode := info.Mode().String() - size := info.Size() - modTime := info.ModTime().Format("Jan _2 15:04") - fmt.Printf("%s %8d %s %s\n", mode, size, modTime, name) - } - } else { - fmt.Println(name) - } - } - return nil +func toInt(s string) int { + n, _ := strconv.Atoi(strings.TrimSpace(s)) + return n } -func commandCat(args []string) error { - for _, path := range args { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("cat: %s: %v", path, err) - } - fmt.Print(string(data)) - } - return nil +func (s *Shell) builtinRead(args []string) error { + prompt := "" + silent := false + rawMode := false + varNames := []string{"REPLY"} + + i := 0 + for i < len(args) { + switch args[i] { + case "-r": + rawMode = true + case "-s": + silent = true + case "-p": + i++ + if i < len(args) { + prompt = args[i] + } + default: + if !strings.HasPrefix(args[i], "-") { + varNames = args[i:] + i = len(args) + continue + } + } + i++ + } + _ = silent // terminal raw mode for passwords would need platform code + + if prompt != "" { + fmt.Fprint(s.Stderr, prompt) + } + + reader := bufio.NewReader(s.Stdin) + line, err := reader.ReadString('\n') + if err != nil && line == "" { + return err + } + line = strings.TrimRight(line, "\r\n") + if !rawMode { + line = strings.ReplaceAll(line, "\\\n", " ") + } + + if len(varNames) == 1 { + s.vars[varNames[0]] = line + } else { + parts := strings.Fields(line) + for j, name := range varNames { + if j < len(parts) { + if j == len(varNames)-1 { + s.vars[name] = strings.Join(parts[j:], " ") + } else { + s.vars[name] = parts[j] + } + } else { + s.vars[name] = "" + } + } + } + return nil } -func commandGrep(args []string) error { - if len(args) < 1 { - return fmt.Errorf("grep: usage: grep [pattern] [file...]") - } - pattern := args[0] - files := args[1:] - ignoreCase := false - if strings.HasPrefix(pattern, "-i") && len(args) > 1 { - ignoreCase = true - pattern = strings.TrimPrefix(pattern, "-i") - if pattern == "" && len(args) > 1 { - pattern = args[1] - files = args[2:] - } - } +func (s *Shell) builtinPrintf(args []string) error { + if len(args) == 0 { + return nil + } + format := args[0] + fmtArgs := args[1:] - for _, path := range files { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("grep: %s: %v", path, err) - } - lines := strings.Split(string(data), "\n") - for _, line := range lines { - check := line - pat := pattern - if ignoreCase { - check = strings.ToLower(line) - pat = strings.ToLower(pattern) - } - if strings.Contains(check, pat) { - fmt.Println(line) - } - } - } - return nil + var result strings.Builder + argIdx := 0 + i := 0 + for i < len(format) { + if format[i] == '%' && i+1 < len(format) { + i++ + // Optional width/precision + specStart := i + for i < len(format) && (format[i] == '-' || format[i] == '0' || (format[i] >= '1' && format[i] <= '9') || format[i] == '.') { + i++ + } + spec := format[specStart:i] + if i >= len(format) { + break + } + arg := "" + if argIdx < len(fmtArgs) { + arg = fmtArgs[argIdx] + argIdx++ + } + switch format[i] { + case 's': + if spec != "" { + result.WriteString(fmt.Sprintf("%-"+spec+"s", arg)) //nolint + } else { + result.WriteString(arg) + } + case 'd': + n := toInt(arg) + if spec != "" { + result.WriteString(fmt.Sprintf("%"+spec+"d", n)) + } else { + result.WriteString(strconv.Itoa(n)) + } + case 'f': + f, _ := strconv.ParseFloat(arg, 64) + if spec != "" { + result.WriteString(fmt.Sprintf("%"+spec+"f", f)) + } else { + result.WriteString(fmt.Sprintf("%f", f)) + } + case 'x': + n := toInt(arg) + result.WriteString(fmt.Sprintf("%x", n)) + case 'o': + n := toInt(arg) + result.WriteString(fmt.Sprintf("%o", n)) + case '%': + result.WriteByte('%') + argIdx-- // no arg consumed + default: + result.WriteByte('%') + result.WriteByte(format[i]) + argIdx-- + } + i++ + } else if format[i] == '\\' && i+1 < len(format) { + i++ + switch format[i] { + case 'n': + result.WriteByte('\n') + case 't': + result.WriteByte('\t') + case 'r': + result.WriteByte('\r') + case '\\': + result.WriteByte('\\') + case 'a': + result.WriteByte('\a') + case 'b': + result.WriteByte('\b') + default: + result.WriteByte('\\') + result.WriteByte(format[i]) + } + i++ + } else { + result.WriteByte(format[i]) + i++ + } + } + fmt.Fprint(s.Stdout, result.String()) + return nil } -func commandSort(args []string) error { - var lines []string - for _, path := range args { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("sort: %s: %v", path, err) - } - lines = append(lines, strings.Split(string(data), "\n")...) - } - sortStrings(lines) - for _, line := range lines { - fmt.Println(line) - } - return nil +func (s *Shell) builtinTrue(_ []string) error { return nil } +func (s *Shell) builtinFalse(_ []string) error { return exitCodeErr{1} } + +func (s *Shell) builtinSet(args []string) error { + if len(args) == 0 { + keys := make([]string, 0, len(s.vars)) + for k := range s.vars { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + fmt.Fprintf(s.Stdout, "%s=%s\n", k, s.vars[k]) + } + return nil + } + // Handle positional params: set -- a b c + if args[0] == "--" { + s.SetArgs(args[1:]) + } + return nil } -func sortStrings(s []string) { - n := len(s) - for i := 0; i < n; i++ { - for j := i + 1; j < n; j++ { - if s[i] > s[j] { - s[i], s[j] = s[j], s[i] - } - } - } +func (s *Shell) builtinUnset(args []string) error { + for _, arg := range args { + delete(s.vars, arg) + os.Unsetenv(arg) + delete(s.funcs, arg) + delete(s.builtins, arg) + } + return nil } -func commandWc(args []string) error { - if len(args) == 0 { - return fmt.Errorf("wc: stdin not supported yet") - } - totalL, totalW, totalC := 0, 0, 0 - for _, path := range args { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("wc: %s: %v", path, err) - } - content := string(data) - l, w, c := strings.Count(content, "\n"), len(strings.Fields(content)), len(content) - totalL += l; totalW += w; totalC += c - fmt.Printf("%8d %8d %8d %s\n", l, w, c, path) - } - if len(args) > 1 { - fmt.Printf("%8d %8d %8d total\n", totalL, totalW, totalC) - } - return nil +func (s *Shell) builtinEnv(args []string) error { + if len(args) == 0 { + for _, env := range os.Environ() { + fmt.Fprintln(s.Stdout, env) + } + return nil + } + // env VAR=val command + // Just execute the command with the vars set + return s.Execute(strings.Join(args, " ")) } -func commandHead(args []string) error { - n := 10 - files := args - if len(args) > 0 && args[0] == "-n" && len(args) > 1 { - fmt.Sscanf(args[1], "%d", &n) - files = args[2:] - } - for fi, path := range files { - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("head: %s: %v", path, err) - } - if len(files) > 1 { - fmt.Printf("==> %s <==\n", path) - } - splitLines := strings.Split(string(data), "\n") - end := n - if end > len(splitLines) { end = len(splitLines) } - for _, line := range splitLines[:end] { - fmt.Println(line) - } - if fi < len(files)-1 { fmt.Println() } - } - return nil +func (s *Shell) builtinWhich(args []string) error { + found := false + for _, arg := range args { + if p := findExecutable(arg); p != "" { + fmt.Fprintln(s.Stdout, p) + found = true + } else if _, ok := s.builtins[arg]; ok { + fmt.Fprintf(s.Stdout, "%s: shell builtin\n", arg) + found = true + } else { + fmt.Fprintf(s.Stderr, "%s not found\n", arg) + } + } + if !found { + return exitCodeErr{1} + } + return nil } -func commandFind(args []string) error { - root := "." - name := "" - for i, arg := range args { - if arg == "-name" && i+1 < len(args) { - name = args[i+1] - } else if !strings.HasPrefix(arg, "-") && root == "." { - root = arg - } - } - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - if err != nil { return err } - if name != "" { - matched, _ := filepath.Match(name, info.Name()) - if !matched { return nil } - } - fmt.Println(path) - return nil - }) - return err +func (s *Shell) builtinReturn(args []string) error { + code := 0 + if len(args) > 0 { + fmt.Sscanf(args[0], "%d", &code) + } else { + fmt.Sscanf(s.vars["?"], "%d", &code) + } + return returnErr{code} } -func commandCp(args []string) error { - if len(args) < 2 { return fmt.Errorf("cp: missing file operand") } - data, err := os.ReadFile(args[0]) - if err != nil { return fmt.Errorf("cp: %v", err) } - return os.WriteFile(args[1], data, 0644) +func (s *Shell) builtinBreak(args []string) error { + n := 1 + if len(args) > 0 { + fmt.Sscanf(args[0], "%d", &n) + } + return breakErr{n} } -func commandMv(args []string) error { - if len(args) < 2 { return fmt.Errorf("mv: missing file operand") } - return os.Rename(args[0], args[1]) +func (s *Shell) builtinContinue(args []string) error { + n := 1 + if len(args) > 0 { + fmt.Sscanf(args[0], "%d", &n) + } + return continueErr{n} } -func commandRm(args []string) error { - recursive := false - for _, arg := range args { - switch arg { - case "-rf", "-fr", "-r": - recursive = true - default: - if !strings.HasPrefix(arg, "-") { - if recursive { return os.RemoveAll(arg) } - return os.Remove(arg) - } - } - } - return nil +func (s *Shell) builtinShift(args []string) error { + n := 1 + if len(args) > 0 { + fmt.Sscanf(args[0], "%d", &n) + } + if n > len(s.args) { + n = len(s.args) + } + s.SetArgs(s.args[n:]) + return nil } -func commandMkdir(args []string) error { - parents := false - var dirs []string - for _, arg := range args { - if arg == "-p" { parents = true } else { dirs = append(dirs, arg) } - } - for _, dir := range dirs { - if parents { - if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("mkdir: %v", err) } - } else { - if err := os.Mkdir(dir, 0755); err != nil { return fmt.Errorf("mkdir: %v", err) } - } - } - return nil +func (s *Shell) builtinDeclare(args []string) error { + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + continue + } + if idx := strings.Index(arg, "="); idx > 0 { + name := arg[:idx] + if isValidIdentifier(name) { + s.vars[name] = arg[idx+1:] + } + } + } + return nil } -func commandTouch(args []string) error { - for _, path := range args { - f, err := os.Create(path) - if err != nil { return fmt.Errorf("touch: %v", err) } - f.Close() - } - return nil +func (s *Shell) builtinCommand(args []string) error { + if len(args) == 0 { + return nil + } + // Skip aliases/functions, go straight to external + return s.executeExternal(args[0], args[1:]) } -func commandClear(args []string) error { - fmt.Print("\033[H\033[2J") - return nil +func (s *Shell) builtinJobs(_ []string) error { + // Basic stub — full job control would require tracking background pids + return nil +} + +// ─── Coreutils ──────────────────────────────────────────────────────────────── + +func (s *Shell) cmdLs(args []string) error { + path := "." + showAll := false + longFormat := false + humanReadable := false + + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + for _, ch := range arg[1:] { + switch ch { + case 'a', 'A': + showAll = true + case 'l': + longFormat = true + case 'h': + humanReadable = true + } + } + } else { + path = arg + } + } + + entries, err := os.ReadDir(path) + if err != nil { + return fmt.Errorf("ls: %v", err) + } + + for _, entry := range entries { + name := entry.Name() + if !showAll && strings.HasPrefix(name, ".") { + continue + } + if longFormat { + info, _ := entry.Info() + if info != nil { + size := info.Size() + sizeStr := "" + if humanReadable { + sizeStr = humanSize(size) + } else { + sizeStr = fmt.Sprintf("%8d", size) + } + modTime := info.ModTime().Format("Jan _2 15:04") + fmt.Fprintf(s.Stdout, "%s %s %s %s\n", info.Mode(), sizeStr, modTime, name) + } + } else { + fmt.Fprintln(s.Stdout, name) + } + } + return nil +} + +func humanSize(n int64) string { + units := []string{"B", "K", "M", "G", "T"} + f := float64(n) + for _, u := range units { + if f < 1024 { + return fmt.Sprintf("%5.1f%s", f, u) + } + f /= 1024 + } + return fmt.Sprintf("%5.1fP", f) +} + +func (s *Shell) cmdCat(args []string) error { + if len(args) == 0 { + _, err := io.Copy(s.Stdout, s.Stdin) + return err + } + for _, path := range args { + if path == "-" { + io.Copy(s.Stdout, s.Stdin) + continue + } + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("cat: %s: %v", path, err) + } + fmt.Fprint(s.Stdout, string(data)) + } + return nil +} + +func (s *Shell) cmdGrep(args []string) error { + ignoreCase := false + invertMatch := false + lineNumber := false + countOnly := false + quiet := false + var patterns []string + var files []string + + i := 0 + for i < len(args) { + arg := args[i] + if strings.HasPrefix(arg, "-") && arg != "-" { + for _, ch := range arg[1:] { + switch ch { + case 'i': + ignoreCase = true + case 'v': + invertMatch = true + case 'n': + lineNumber = true + case 'c': + countOnly = true + case 'q': + quiet = true + case 'e': + i++ + if i < len(args) { + patterns = append(patterns, args[i]) + } + } + } + } else if len(patterns) == 0 { + patterns = append(patterns, arg) + } else { + files = append(files, arg) + } + i++ + } + + if len(patterns) == 0 { + return fmt.Errorf("grep: no pattern given") + } + + var readers []io.Reader + var names []string + if len(files) == 0 { + readers = append(readers, s.Stdin) + names = append(names, "") + } else { + for _, f := range files { + data, err := os.ReadFile(f) + if err != nil { + fmt.Fprintf(s.Stderr, "grep: %s: %v\n", f, err) + continue + } + readers = append(readers, strings.NewReader(string(data))) + names = append(names, f) + } + } + + matchCount := 0 + for ri, r := range readers { + name := names[ri] + lineNum := 0 + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + lineNum++ + matched := false + for _, pat := range patterns { + check, p := line, pat + if ignoreCase { + check = strings.ToLower(check) + p = strings.ToLower(p) + } + if strings.Contains(check, p) { + matched = true + break + } + } + if invertMatch { + matched = !matched + } + if matched { + matchCount++ + if quiet { + return nil + } + if !countOnly { + prefix := "" + if name != "" && len(files) > 1 { + prefix = name + ":" + } + if lineNumber { + prefix += fmt.Sprintf("%d:", lineNum) + } + fmt.Fprintln(s.Stdout, prefix+line) + } + } + } + if countOnly { + fmt.Fprintln(s.Stdout, matchCount) + } + } + + if quiet && matchCount == 0 { + return exitCodeErr{1} + } + if matchCount == 0 && !countOnly { + return exitCodeErr{1} + } + return nil +} + +func (s *Shell) cmdHead(args []string) error { + n := 10 + var files []string + + i := 0 + for i < len(args) { + switch { + case args[i] == "-n" && i+1 < len(args): + i++ + fmt.Sscanf(args[i], "%d", &n) + case strings.HasPrefix(args[i], "-n"): + fmt.Sscanf(args[i][2:], "%d", &n) + case strings.HasPrefix(args[i], "-") && len(args[i]) > 1: + fmt.Sscanf(args[i][1:], "%d", &n) + default: + files = append(files, args[i]) + } + i++ + } + + readHead := func(r io.Reader, name string) error { + scanner := bufio.NewScanner(r) + count := 0 + for scanner.Scan() { + if count >= n { + break + } + fmt.Fprintln(s.Stdout, scanner.Text()) + count++ + } + return nil + } + + if len(files) == 0 { + return readHead(s.Stdin, "") + } + for _, f := range files { + if len(files) > 1 { + fmt.Fprintf(s.Stdout, "==> %s <==\n", f) + } + fh, err := os.Open(f) + if err != nil { + return fmt.Errorf("head: %s: %v", f, err) + } + readHead(fh, f) + fh.Close() + } + return nil +} + +func (s *Shell) cmdTail(args []string) error { + n := 10 + var files []string + + i := 0 + for i < len(args) { + switch { + case args[i] == "-n" && i+1 < len(args): + i++ + fmt.Sscanf(args[i], "%d", &n) + case strings.HasPrefix(args[i], "-n"): + fmt.Sscanf(args[i][2:], "%d", &n) + case strings.HasPrefix(args[i], "-") && len(args[i]) > 1: + fmt.Sscanf(args[i][1:], "%d", &n) + default: + files = append(files, args[i]) + } + i++ + } + + readTail := func(r io.Reader) error { + scanner := bufio.NewScanner(r) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + start := len(lines) - n + if start < 0 { + start = 0 + } + for _, l := range lines[start:] { + fmt.Fprintln(s.Stdout, l) + } + return nil + } + + if len(files) == 0 { + return readTail(s.Stdin) + } + for _, f := range files { + if len(files) > 1 { + fmt.Fprintf(s.Stdout, "==> %s <==\n", f) + } + fh, err := os.Open(f) + if err != nil { + return fmt.Errorf("tail: %s: %v", f, err) + } + readTail(fh) + fh.Close() + } + return nil +} + +func (s *Shell) cmdSort(args []string) error { + reverse := false + unique := false + numeric := false + var files []string + + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + for _, ch := range arg[1:] { + switch ch { + case 'r': + reverse = true + case 'u': + unique = true + case 'n': + numeric = true + } + } + } else { + files = append(files, arg) + } + } + + var lines []string + readLines := func(r io.Reader) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + } + + if len(files) == 0 { + readLines(s.Stdin) + } else { + for _, f := range files { + fh, err := os.Open(f) + if err != nil { + fmt.Fprintf(s.Stderr, "sort: %s: %v\n", f, err) + continue + } + readLines(fh) + fh.Close() + } + } + + if numeric { + sort.Slice(lines, func(i, j int) bool { + ni := toInt(lines[i]) + nj := toInt(lines[j]) + if reverse { + return ni > nj + } + return ni < nj + }) + } else { + sort.Strings(lines) + if reverse { + for l, r := 0, len(lines)-1; l < r; l, r = l+1, r-1 { + lines[l], lines[r] = lines[r], lines[l] + } + } + } + + seen := map[string]bool{} + for _, l := range lines { + if unique { + if seen[l] { + continue + } + seen[l] = true + } + fmt.Fprintln(s.Stdout, l) + } + return nil +} + +func (s *Shell) cmdWc(args []string) error { + countLines := true + countWords := true + countBytes := true + + var files []string + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + countLines = strings.Contains(arg, "l") + countWords = strings.Contains(arg, "w") + countBytes = strings.Contains(arg, "c") + } else { + files = append(files, arg) + } + } + + doWc := func(r io.Reader, name string) { + data, _ := io.ReadAll(r) + content := string(data) + l := strings.Count(content, "\n") + w := len(strings.Fields(content)) + c := len(data) + out := "" + if countLines { + out += fmt.Sprintf("%8d", l) + } + if countWords { + out += fmt.Sprintf("%8d", w) + } + if countBytes { + out += fmt.Sprintf("%8d", c) + } + if name != "" { + out += " " + name + } + fmt.Fprintln(s.Stdout, out) + } + + if len(files) == 0 { + doWc(s.Stdin, "") + return nil + } + for _, f := range files { + fh, err := os.Open(f) + if err != nil { + fmt.Fprintf(s.Stderr, "wc: %s: %v\n", f, err) + continue + } + doWc(fh, f) + fh.Close() + } + return nil +} + +func (s *Shell) cmdFind(args []string) error { + root := "." + name := "" + typeFlag := "" + maxDepth := -1 + + for i := 0; i < len(args); i++ { + switch args[i] { + case "-name": + if i+1 < len(args) { + i++ + name = args[i] + } + case "-type": + if i+1 < len(args) { + i++ + typeFlag = args[i] + } + case "-maxdepth": + if i+1 < len(args) { + i++ + fmt.Sscanf(args[i], "%d", &maxDepth) + } + default: + if !strings.HasPrefix(args[i], "-") { + root = args[i] + } + } + } + + rootAbs, _ := filepath.Abs(root) + return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + // Check maxdepth + if maxDepth >= 0 { + rel, _ := filepath.Rel(rootAbs, path) + depth := strings.Count(rel, string(os.PathSeparator)) + if depth > maxDepth { + if info.IsDir() { + return filepath.SkipDir + } + return nil + } + } + if name != "" { + matched, _ := filepath.Match(name, info.Name()) + if !matched { + return nil + } + } + if typeFlag != "" { + switch typeFlag { + case "f": + if !info.Mode().IsRegular() { + return nil + } + case "d": + if !info.IsDir() { + return nil + } + case "l": + if info.Mode()&os.ModeSymlink == 0 { + return nil + } + } + } + fmt.Fprintln(s.Stdout, path) + return nil + }) +} + +func (s *Shell) cmdCp(args []string) error { + recursive := false + var sources []string + dest := "" + + for _, arg := range args { + if arg == "-r" || arg == "-R" || arg == "-rf" { + recursive = true + } else { + sources = append(sources, arg) + } + } + if len(sources) < 2 { + return fmt.Errorf("cp: missing destination") + } + dest = sources[len(sources)-1] + sources = sources[:len(sources)-1] + + for _, src := range sources { + info, err := os.Stat(src) + if err != nil { + return fmt.Errorf("cp: %v", err) + } + dstPath := dest + if dstInfo, err := os.Stat(dest); err == nil && dstInfo.IsDir() { + dstPath = filepath.Join(dest, filepath.Base(src)) + } + if info.IsDir() { + if !recursive { + return fmt.Errorf("cp: %s: is a directory (use -r)", src) + } + if err := copyDir(src, dstPath); err != nil { + return fmt.Errorf("cp: %v", err) + } + } else { + if err := copyFile(src, dstPath); err != nil { + return fmt.Errorf("cp: %v", err) + } + } + } + return nil +} + +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err + } + info, _ := os.Stat(src) + perm := os.FileMode(0644) + if info != nil { + perm = info.Mode() + } + return os.WriteFile(dst, data, perm) +} + +func copyDir(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + rel, _ := filepath.Rel(src, path) + target := filepath.Join(dst, rel) + if info.IsDir() { + return os.MkdirAll(target, info.Mode()) + } + return copyFile(path, target) + }) +} + +func (s *Shell) cmdMv(args []string) error { + if len(args) < 2 { + return fmt.Errorf("mv: missing destination") + } + src, dst := args[0], args[1] + if dstInfo, err := os.Stat(dst); err == nil && dstInfo.IsDir() { + dst = filepath.Join(dst, filepath.Base(src)) + } + return os.Rename(src, dst) +} + +func (s *Shell) cmdRm(args []string) error { + recursive := false + force := false + var targets []string + + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + if strings.Contains(arg, "r") || strings.Contains(arg, "R") { + recursive = true + } + if strings.Contains(arg, "f") { + force = true + } + } else { + targets = append(targets, arg) + } + } + + for _, t := range targets { + var err error + if recursive { + err = os.RemoveAll(t) + } else { + err = os.Remove(t) + } + if err != nil && !force { + return fmt.Errorf("rm: %v", err) + } + } + return nil +} + +func (s *Shell) cmdMkdir(args []string) error { + parents := false + var dirs []string + for _, arg := range args { + if arg == "-p" { + parents = true + } else { + dirs = append(dirs, arg) + } + } + for _, dir := range dirs { + var err error + if parents { + err = os.MkdirAll(dir, 0755) + } else { + err = os.Mkdir(dir, 0755) + } + if err != nil { + return fmt.Errorf("mkdir: %v", err) + } + } + return nil +} + +func (s *Shell) cmdTouch(args []string) error { + for _, path := range args { + if _, err := os.Stat(path); err == nil { + now := time.Now() + os.Chtimes(path, now, now) + } else { + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("touch: %v", err) + } + f.Close() + } + } + return nil +} + +func (s *Shell) cmdClear(_ []string) error { + fmt.Fprint(s.Stdout, "\033[H\033[2J") + return nil +} + +func (s *Shell) cmdCut(args []string) error { + delimiter := "\t" + var fields []int + var files []string + + i := 0 + for i < len(args) { + switch { + case args[i] == "-d" && i+1 < len(args): + i++ + delimiter = args[i] + case strings.HasPrefix(args[i], "-d"): + delimiter = args[i][2:] + case args[i] == "-f" && i+1 < len(args): + i++ + for _, part := range strings.Split(args[i], ",") { + if n := toInt(part); n > 0 { + fields = append(fields, n-1) // 0-indexed + } + } + case strings.HasPrefix(args[i], "-f"): + for _, part := range strings.Split(args[i][2:], ",") { + if n := toInt(part); n > 0 { + fields = append(fields, n-1) + } + } + default: + files = append(files, args[i]) + } + i++ + } + + doCut := func(r io.Reader) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + parts := strings.Split(line, delimiter) + var out []string + for _, f := range fields { + if f < len(parts) { + out = append(out, parts[f]) + } + } + fmt.Fprintln(s.Stdout, strings.Join(out, delimiter)) + } + } + + if len(files) == 0 { + doCut(s.Stdin) + return nil + } + for _, f := range files { + fh, err := os.Open(f) + if err != nil { + return fmt.Errorf("cut: %s: %v", f, err) + } + doCut(fh) + fh.Close() + } + return nil +} + +func (s *Shell) cmdTr(args []string) error { + if len(args) < 1 { + return fmt.Errorf("tr: missing operand") + } + deleteMode := false + squeezeMode := false + var rawSet1, rawSet2 string + + i := 0 + for i < len(args) { + switch args[i] { + case "-d": + deleteMode = true + case "-s": + squeezeMode = true + default: + if rawSet1 == "" { + rawSet1 = args[i] + } else { + rawSet2 = args[i] + } + } + i++ + } + + set1 := expandTrSet(rawSet1) + set2 := expandTrSet(rawSet2) + + data, _ := io.ReadAll(s.Stdin) + result := string(data) + + if deleteMode { + set := map[rune]bool{} + for _, ch := range set1 { + set[ch] = true + } + var out strings.Builder + for _, ch := range result { + if !set[ch] { + out.WriteRune(ch) + } + } + result = out.String() + } else if len(set2) > 0 { + var out strings.Builder + prevChar := rune(0) + for _, ch := range result { + idx := -1 + for j, c := range set1 { + if c == ch { + idx = j + break + } + } + if idx >= 0 { + newCh := set2[idx] + if idx >= len(set2) { + newCh = set2[len(set2)-1] // pad with last char + } + if squeezeMode && newCh == prevChar { + continue + } + out.WriteRune(newCh) + prevChar = newCh + } else { + out.WriteRune(ch) + prevChar = ch + } + } + result = out.String() + } + + fmt.Fprint(s.Stdout, result) + return nil +} + +// expandTrSet expands a tr character set string, handling a-z ranges and \n, \t etc. +func expandTrSet(s string) []rune { + var result []rune + i := 0 + runes := []rune(s) + for i < len(runes) { + if runes[i] == '\\' && i+1 < len(runes) { + i++ + switch runes[i] { + case 'n': + result = append(result, '\n') + case 't': + result = append(result, '\t') + case 'r': + result = append(result, '\r') + default: + result = append(result, runes[i]) + } + i++ + } else if i+2 < len(runes) && runes[i+1] == '-' { + // Range like a-z + from, to := runes[i], runes[i+2] + for c := from; c <= to; c++ { + result = append(result, c) + } + i += 3 + } else { + result = append(result, runes[i]) + i++ + } + } + return result +} + +func (s *Shell) cmdUniq(args []string) error { + countMode := false + duplicateMode := false + uniqueMode := false + ignoreCase := false + var files []string + + for _, arg := range args { + if strings.HasPrefix(arg, "-") { + for _, ch := range arg[1:] { + switch ch { + case 'c': + countMode = true + case 'd': + duplicateMode = true + case 'u': + uniqueMode = true + case 'i': + ignoreCase = true + } + } + } else { + files = append(files, arg) + } + } + + doUniq := func(r io.Reader) { + scanner := bufio.NewScanner(r) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + i := 0 + for i < len(lines) { + line := lines[i] + count := 1 + key := line + if ignoreCase { + key = strings.ToLower(key) + } + for i+count < len(lines) { + nextKey := lines[i+count] + if ignoreCase { + nextKey = strings.ToLower(nextKey) + } + if nextKey != key { + break + } + count++ + } + + show := true + if duplicateMode && count == 1 { + show = false + } + if uniqueMode && count > 1 { + show = false + } + if show { + if countMode { + fmt.Fprintf(s.Stdout, "%7d %s\n", count, line) + } else { + fmt.Fprintln(s.Stdout, line) + } + } + i += count + } + } + + if len(files) == 0 { + doUniq(s.Stdin) + return nil + } + for _, f := range files { + fh, err := os.Open(f) + if err != nil { + return fmt.Errorf("uniq: %s: %v", f, err) + } + doUniq(fh) + fh.Close() + } + return nil +} + +func (s *Shell) cmdTee(args []string) error { + appendMode := false + var files []string + + for _, arg := range args { + if arg == "-a" { + appendMode = true + } else { + files = append(files, arg) + } + } + + var writers []io.Writer + writers = append(writers, s.Stdout) + var toClose []io.Closer + + for _, f := range files { + var fh *os.File + var err error + if appendMode { + fh, err = os.OpenFile(f, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + } else { + fh, err = os.Create(f) + } + if err != nil { + fmt.Fprintf(s.Stderr, "tee: %s: %v\n", f, err) + continue + } + writers = append(writers, fh) + toClose = append(toClose, fh) + } + + defer func() { + for _, c := range toClose { + c.Close() + } + }() + + mw := io.MultiWriter(writers...) + _, err := io.Copy(mw, s.Stdin) + return err +} + +func (s *Shell) cmdDate(args []string) error { + format := time.RFC3339 + now := time.Now() + + for _, arg := range args { + if strings.HasPrefix(arg, "+") { + // Convert strftime format to Go format + format = convertDateFormat(arg[1:]) + } + } + + fmt.Fprintln(s.Stdout, now.Format(format)) + return nil +} + +func convertDateFormat(f string) string { + replacements := map[string]string{ + "%Y": "2006", + "%m": "01", + "%d": "02", + "%H": "15", + "%M": "04", + "%S": "05", + "%A": "Monday", + "%a": "Mon", + "%B": "January", + "%b": "Jan", + "%n": "\n", + "%t": "\t", + "%%": "%", + } + for k, v := range replacements { + f = strings.ReplaceAll(f, k, v) + } + return f +} + +func (s *Shell) cmdSleep(args []string) error { + if len(args) == 0 { + return fmt.Errorf("sleep: missing operand") + } + var total time.Duration + for _, arg := range args { + if strings.HasSuffix(arg, "m") { + n, _ := strconv.ParseFloat(arg[:len(arg)-1], 64) + total += time.Duration(n * float64(time.Minute)) + } else if strings.HasSuffix(arg, "h") { + n, _ := strconv.ParseFloat(arg[:len(arg)-1], 64) + total += time.Duration(n * float64(time.Hour)) + } else if strings.HasSuffix(arg, "s") { + n, _ := strconv.ParseFloat(arg[:len(arg)-1], 64) + total += time.Duration(n * float64(time.Second)) + } else { + n, _ := strconv.ParseFloat(arg, 64) + total += time.Duration(n * float64(time.Second)) + } + } + time.Sleep(total) + return nil +} + +func (s *Shell) cmdBasename(args []string) error { + if len(args) == 0 { + return fmt.Errorf("basename: missing operand") + } + result := filepath.Base(args[0]) + if len(args) > 1 { + result = strings.TrimSuffix(result, args[1]) + } + fmt.Fprintln(s.Stdout, result) + return nil +} + +func (s *Shell) cmdDirname(args []string) error { + if len(args) == 0 { + return fmt.Errorf("dirname: missing operand") + } + fmt.Fprintln(s.Stdout, filepath.Dir(args[0])) + return nil +} + +// cmdSed implements a very basic subset of sed: s/pattern/replacement/[g] +func (s *Shell) cmdSed(args []string) error { + var script string + var files []string + inPlace := false + + i := 0 + for i < len(args) { + switch { + case args[i] == "-e" && i+1 < len(args): + i++ + script = args[i] + case args[i] == "-i" || args[i] == "--in-place": + inPlace = true + case strings.HasPrefix(args[i], "-i"): + inPlace = true + case !strings.HasPrefix(args[i], "-") && script == "": + script = args[i] + default: + files = append(files, args[i]) + } + i++ + } + + if script == "" { + return fmt.Errorf("sed: no script") + } + + doSed := func(r io.Reader) (string, error) { + scanner := bufio.NewScanner(r) + var out strings.Builder + for scanner.Scan() { + line := scanner.Text() + line = applySedScript(line, script) + out.WriteString(line + "\n") + } + return out.String(), nil + } + + if len(files) == 0 { + result, err := doSed(s.Stdin) + if err != nil { + return err + } + fmt.Fprint(s.Stdout, result) + return nil + } + + for _, f := range files { + fh, err := os.Open(f) + if err != nil { + return fmt.Errorf("sed: %s: %v", f, err) + } + result, err := doSed(fh) + fh.Close() + if err != nil { + return err + } + if inPlace { + os.WriteFile(f, []byte(result), 0644) + } else { + fmt.Fprint(s.Stdout, result) + } + } + return nil +} + +func applySedScript(line, script string) string { + // Handle: s/pattern/replacement/ or s/pattern/replacement/g + if !strings.HasPrefix(script, "s") { + return line + } + if len(script) < 2 { + return line + } + sep := string(script[1]) + parts := strings.SplitN(script[2:], sep, 3) + if len(parts) < 2 { + return line + } + pattern := parts[0] + replacement := parts[1] + flags := "" + if len(parts) > 2 { + flags = parts[2] + } + if strings.Contains(flags, "g") { + return strings.ReplaceAll(line, pattern, replacement) + } + return strings.Replace(line, pattern, replacement, 1) +} + +func (s *Shell) cmdXargs(args []string) error { + if len(args) == 0 { + args = []string{"echo"} + } + + scanner := bufio.NewScanner(s.Stdin) + var inputArgs []string + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line != "" { + inputArgs = append(inputArgs, strings.Fields(line)...) + } + } + + if len(inputArgs) == 0 { + return nil + } + + return s.executeCommand(strings.Join(args, " ") + " " + strings.Join(inputArgs, " ")) } diff --git a/internal/shell/control.go b/internal/shell/control.go new file mode 100644 index 0000000..77c0f48 --- /dev/null +++ b/internal/shell/control.go @@ -0,0 +1,352 @@ +package shell + +import ( + "fmt" + "strings" +) + +// executeIf handles: if COND; then BODY; [elif COND; then BODY;]* [else BODY;] fi +func (s *Shell) executeIf(block string) error { + stmts := splitStatements(block) + + type branch struct { + cond []string + body []string + } + + var branches []branch + var elseBody []string + + phase := "if_cond" + var curCond []string + var curBody []string + + for _, stmt := range stmts { + w := firstWord(stmt) + rest := afterWord(stmt) + + switch { + case w == "if" && phase == "if_cond": + if rest != "" { + curCond = append(curCond, rest) + } + case w == "then": + if rest != "" { + curBody = append(curBody, rest) + } + phase = "body" + case w == "elif": + branches = append(branches, branch{curCond, curBody}) + curCond = nil + curBody = nil + if rest != "" { + curCond = append(curCond, rest) + } + phase = "elif_cond" + case w == "else": + branches = append(branches, branch{curCond, curBody}) + curCond = nil + curBody = nil + if rest != "" { + elseBody = append(elseBody, rest) + } + phase = "else" + case w == "fi": + switch phase { + case "body": + branches = append(branches, branch{curCond, curBody}) + case "else": + if rest != "" { + elseBody = append(elseBody, rest) + } + } + default: + switch phase { + case "if_cond", "elif_cond": + curCond = append(curCond, stmt) + case "body": + curBody = append(curBody, stmt) + case "else": + elseBody = append(elseBody, stmt) + } + } + } + + for _, b := range branches { + cond := strings.Join(b.cond, "\n") + s.Execute(cond) //nolint — we only care about $? + if s.vars["?"] == "0" { + return s.Execute(strings.Join(b.body, "\n")) + } + } + if len(elseBody) > 0 { + return s.Execute(strings.Join(elseBody, "\n")) + } + return nil +} + +// executeFor handles: for VAR in WORDS; do BODY; done +// (also: for VAR; do BODY; done — iterates positional params) +func (s *Shell) executeFor(block string) error { + stmts := splitStatements(block) + if len(stmts) == 0 { + return nil + } + + // Parse "for VAR in WORDS" + header := stmts[0] + fields := strings.Fields(header) + if len(fields) < 2 { + return fmt.Errorf("for: bad syntax") + } + varName := fields[1] + + var items []string + inIdx := -1 + for i, w := range fields { + if w == "in" { + inIdx = i + break + } + } + if inIdx >= 0 { + for _, raw := range fields[inIdx+1:] { + expanded := s.expandWord(raw) + items = append(items, s.expandGlob(expanded)...) + } + } else { + // for var; do ... → iterate positional params + items = s.args + } + + // Collect body between "do" and "done" + var bodyStmts []string + inBody := false + for _, stmt := range stmts[1:] { + w := firstWord(stmt) + if !inBody { + if w == "do" { + inBody = true + if rest := afterWord(stmt); rest != "" { + bodyStmts = append(bodyStmts, rest) + } + } + continue + } + if w == "done" { + break + } + bodyStmts = append(bodyStmts, stmt) + } + + body := strings.Join(bodyStmts, "\n") + + for _, item := range items { + s.vars[varName] = item + if err := s.Execute(body); err != nil { + if be, ok := err.(breakErr); ok { + if be.n <= 1 { + break + } + return breakErr{be.n - 1} + } + if ce, ok := err.(continueErr); ok { + if ce.n <= 1 { + continue + } + return continueErr{ce.n - 1} + } + if _, ok := err.(returnErr); ok { + return err + } + } + } + return nil +} + +// executeWhileUntil handles while/until loops. +func (s *Shell) executeWhileUntil(block string, isUntil bool) error { + stmts := splitStatements(block) + if len(stmts) == 0 { + return nil + } + + // Parse condition (everything from "while/until COND" up to "do") + var condStmts []string + if rest := afterWord(stmts[0]); rest != "" { + condStmts = append(condStmts, rest) + } + + var bodyStmts []string + inBody := false + for _, stmt := range stmts[1:] { + w := firstWord(stmt) + if !inBody { + if w == "do" { + inBody = true + if rest := afterWord(stmt); rest != "" { + bodyStmts = append(bodyStmts, rest) + } + } else { + condStmts = append(condStmts, stmt) + } + continue + } + if w == "done" { + break + } + bodyStmts = append(bodyStmts, stmt) + } + + cond := strings.Join(condStmts, "\n") + body := strings.Join(bodyStmts, "\n") + + for { + s.Execute(cond) //nolint + condOk := s.vars["?"] == "0" + + if (isUntil && condOk) || (!isUntil && !condOk) { + break + } + + if err := s.Execute(body); err != nil { + if be, ok := err.(breakErr); ok { + if be.n <= 1 { + break + } + return breakErr{be.n - 1} + } + if ce, ok := err.(continueErr); ok { + if ce.n <= 1 { + continue + } + return continueErr{ce.n - 1} + } + if _, ok := err.(returnErr); ok { + return err + } + } + } + return nil +} + +// defineFunction parses and registers a shell function definition. +func (s *Shell) defineFunction(block string) error { + stmts := splitStatements(block) + if len(stmts) == 0 { + return fmt.Errorf("syntax error: empty function") + } + + first := stmts[0] + var name string + + if strings.HasPrefix(first, "function ") { + rest := strings.TrimPrefix(first, "function ") + rest = strings.TrimSpace(rest) + // Strip trailing () and { + rest = strings.TrimSuffix(strings.TrimSpace(rest), "{") + rest = strings.TrimSuffix(strings.TrimSpace(rest), "()") + name = strings.TrimSpace(rest) + } else { + parenIdx := strings.Index(first, "(") + if parenIdx < 0 { + return fmt.Errorf("syntax error: bad function definition") + } + name = strings.TrimSpace(first[:parenIdx]) + } + + if !isValidIdentifier(name) { + return fmt.Errorf("syntax error: invalid function name %q", name) + } + + // Find the opening { in the block — it may be on the same line as the name + // or on a following stmt. Everything after { (up to closing }) is the body. + var bodyStmts []string + inBody := false + + for _, stmt := range stmts { + trimmed := strings.TrimSpace(stmt) + + if !inBody { + // Look for { in this stmt + braceIdx := strings.Index(trimmed, "{") + if braceIdx >= 0 { + inBody = true + rest := strings.TrimSpace(trimmed[braceIdx+1:]) + if rest != "" && rest != "}" { + bodyStmts = append(bodyStmts, rest) + } + // Check if } is also on this line (single-liner like name() { cmd; }) + if strings.HasSuffix(trimmed, "}") && braceIdx < len(trimmed)-1 { + // body is between { and } + inner := strings.TrimSpace(trimmed[braceIdx+1 : len(trimmed)-1]) + bodyStmts = nil + if inner != "" { + bodyStmts = append(bodyStmts, inner) + } + break + } + } + continue + } + + if trimmed == "}" { + break + } + bodyStmts = append(bodyStmts, stmt) + } + + funcBody := strings.Join(bodyStmts, "\n") + s.funcs[name] = funcBody + + funcName := name + s.builtins[funcName] = func(args []string) error { + return s.callFunction(funcName, args) + } + return nil +} + +func (s *Shell) callFunction(name string, args []string) error { + body, ok := s.funcs[name] + if !ok { + return fmt.Errorf("%s: function not found", name) + } + + // Save positional params and exit code + oldArgs := s.args + savedPos := map[string]string{} + for k, v := range s.vars { + if k == "#" || k == "@" || k == "*" || (len(k) == 1 && k[0] >= '1' && k[0] <= '9') { + savedPos[k] = v + } + } + + s.SetArgs(args) + s.vars["?"] = "0" // reset before running body + + err := s.Execute(body) + + // Capture the function's exit code BEFORE restoring params (which might not include ?) + funcExitCode := s.lastExit + + // Restore positional params + s.args = oldArgs + for k, v := range savedPos { + s.vars[k] = v + } + + if re, ok := err.(returnErr); ok { + if re.code != 0 { + return exitCodeErr{re.code} + } + return nil + } + if err != nil { + return err + } + // Propagate last command's exit code from the function body + if funcExitCode != 0 { + return exitCodeErr{funcExitCode} + } + return nil +} diff --git a/internal/shell/exec.go b/internal/shell/exec.go new file mode 100644 index 0000000..acdcb57 --- /dev/null +++ b/internal/shell/exec.go @@ -0,0 +1,375 @@ +package shell + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" +) + +type redirect struct { + fd int // 0=stdin 1=stdout 2=stderr -1=both stdout+stderr + mode string // ">" ">>" "<" + dest string // filename or "&1" "&2" +} + +// executePipeline handles background jobs (&) and pipelines (|). +func (s *Shell) executePipeline(input string) error { + input = strings.TrimSpace(input) + if input == "" { + return nil + } + + // Background job: trailing & (not &&) + bg := false + if strings.HasSuffix(input, "&") && !strings.HasSuffix(input, "&&") { + bg = true + input = strings.TrimSuffix(strings.TrimSuffix(input, "&"), " ") + } + + parts := splitPipe(input) + if len(parts) == 1 { + if bg { + return s.executeCommandBg(strings.TrimSpace(parts[0])) + } + return s.executeCommand(strings.TrimSpace(parts[0])) + } + return s.doPipe(parts, bg) +} + +// splitPipe splits by | but not ||, respecting quotes. +func splitPipe(input string) []string { + var parts []string + current := strings.Builder{} + inSingle := false + inDouble := false + parenDepth := 0 + pendingDollar := false + + for i := 0; i < len(input); i++ { + c := input[i] + switch { + case c == '\'' && !inDouble && parenDepth == 0: + inSingle = !inSingle + current.WriteByte(c) + case c == '"' && !inSingle && parenDepth == 0: + inDouble = !inDouble + current.WriteByte(c) + case c == '$' && !inSingle && !inDouble && i+1 < len(input) && input[i+1] == '(': + pendingDollar = true + current.WriteByte(c) + case c == '(' && !inSingle && !inDouble && (parenDepth > 0 || pendingDollar): + parenDepth++ + pendingDollar = false + current.WriteByte(c) + case c == ')' && !inSingle && !inDouble && parenDepth > 0: + parenDepth-- + current.WriteByte(c) + case c == '|' && !inSingle && !inDouble && parenDepth == 0: + if i+1 < len(input) && input[i+1] == '|' { + current.WriteByte(c) // part of ||, pass through + } else { + parts = append(parts, strings.TrimSpace(current.String())) + current.Reset() + pendingDollar = false + } + default: + pendingDollar = false + current.WriteByte(c) + } + } + if current.Len() > 0 { + parts = append(parts, strings.TrimSpace(current.String())) + } + return parts +} + +// executeCommand executes a single command (no pipes, no &&/||). +func (s *Shell) executeCommand(input string) error { + input = strings.TrimSpace(input) + if input == "" { + return nil + } + + tokens := s.tokenize(input) + if len(tokens) == 0 { + return nil + } + + cmdArgs, redirects := extractRedirects(tokens) + if len(cmdArgs) == 0 { + // Pure redirection, e.g. "> file" creates/truncates file + return s.withRedirects(redirects, func() error { return nil }) + } + + cmdName := cmdArgs[0] + args := cmdArgs[1:] + + // Alias expansion + if alias, ok := aliases[cmdName]; ok { + full := alias + if len(args) > 0 { + full += " " + strings.Join(args, " ") + } + return s.withRedirects(redirects, func() error { + return s.Execute(full) + }) + } + + // Builtin + if builtin, ok := s.builtins[cmdName]; ok { + return s.withRedirects(redirects, func() error { + return builtin(args) + }) + } + + // External + return s.withRedirects(redirects, func() error { + return s.executeExternal(cmdName, args) + }) +} + +func extractRedirects(tokens []string) ([]string, []redirect) { + var args []string + var redirects []redirect + + i := 0 + for i < len(tokens) { + tok := tokens[i] + + switch { + // 2>&1 + case tok == "2>&1": + redirects = append(redirects, redirect{2, ">", "&1"}) + i++ + // 1>&2 + case tok == "1>&2": + redirects = append(redirects, redirect{1, ">", "&2"}) + i++ + // &> or &>> (both stdout+stderr) + case tok == "&>" || tok == "&>>": + if i+1 < len(tokens) { + redirects = append(redirects, redirect{-1, strings.TrimPrefix(tok, "&"), tokens[i+1]}) + i += 2 + } else { + i++ + } + case strings.HasPrefix(tok, "&>"): + mode := ">" + dest := tok[2:] + if strings.HasPrefix(dest, ">") { + mode = ">>" + dest = dest[1:] + } + redirects = append(redirects, redirect{-1, mode, dest}) + i++ + // 2> 2>> 2>file + case tok == "2>" || tok == "2>>": + if i+1 < len(tokens) { + redirects = append(redirects, redirect{2, tok[1:], tokens[i+1]}) + i += 2 + } else { + i++ + } + case strings.HasPrefix(tok, "2>>"): + redirects = append(redirects, redirect{2, ">>", tok[3:]}) + i++ + case strings.HasPrefix(tok, "2>"): + redirects = append(redirects, redirect{2, ">", tok[2:]}) + i++ + // > >> + case tok == ">" || tok == ">>": + if i+1 < len(tokens) { + redirects = append(redirects, redirect{1, tok, tokens[i+1]}) + i += 2 + } else { + i++ + } + case strings.HasPrefix(tok, ">>") && len(tok) > 2: + redirects = append(redirects, redirect{1, ">>", tok[2:]}) + i++ + case strings.HasPrefix(tok, ">") && len(tok) > 1: + redirects = append(redirects, redirect{1, ">", tok[1:]}) + i++ + // < (stdin) + case tok == "<": + if i+1 < len(tokens) { + redirects = append(redirects, redirect{0, "<", tokens[i+1]}) + i += 2 + } else { + i++ + } + case strings.HasPrefix(tok, "<") && len(tok) > 1 && tok[1] != '<': + redirects = append(redirects, redirect{0, "<", tok[1:]}) + i++ + default: + args = append(args, tok) + i++ + } + } + return args, redirects +} + +func (s *Shell) withRedirects(redirects []redirect, fn func() error) error { + if len(redirects) == 0 { + return fn() + } + + oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr + var toClose []io.Closer + + defer func() { + s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr + for _, c := range toClose { + c.Close() + } + }() + + for _, r := range redirects { + switch r.mode { + case ">": + if r.dest == "&1" { + s.Stderr = s.Stdout + } else if r.dest == "&2" { + s.Stdout = s.Stderr + } else { + f, err := os.Create(r.dest) + if err != nil { + return fmt.Errorf("cannot open %s: %v", r.dest, err) + } + toClose = append(toClose, f) + if r.fd == 1 || r.fd == -1 { + s.Stdout = f + } + if r.fd == 2 || r.fd == -1 { + s.Stderr = f + } + } + case ">>": + f, err := os.OpenFile(r.dest, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("cannot open %s: %v", r.dest, err) + } + toClose = append(toClose, f) + if r.fd == 1 || r.fd == -1 { + s.Stdout = f + } + if r.fd == 2 || r.fd == -1 { + s.Stderr = f + } + case "<": + f, err := os.Open(r.dest) + if err != nil { + return fmt.Errorf("cannot open %s: %v", r.dest, err) + } + toClose = append(toClose, f) + s.Stdin = f + } + } + return fn() +} + +func (s *Shell) executeExternal(cmdName string, args []string) error { + path := findExecutable(cmdName) + if path == "" { + fmt.Fprintf(s.Stderr, "%s: command not found\n", cmdName) + return exitCodeErr{127} + } + cmd := exec.Command(path, args...) + cmd.Stdin = s.Stdin + cmd.Stdout = s.Stdout + cmd.Stderr = s.Stderr + err := cmd.Run() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return exitCodeErr{exitErr.ExitCode()} + } + } + return err +} + +func (s *Shell) executeCommandBg(input string) error { + tokens := s.tokenize(input) + if len(tokens) == 0 { + return nil + } + cmdArgs, _ := extractRedirects(tokens) + if len(cmdArgs) == 0 { + return nil + } + path := findExecutable(cmdArgs[0]) + if path == "" { + return fmt.Errorf("%s: command not found", cmdArgs[0]) + } + cmd := exec.Command(path, cmdArgs[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + return err + } + pid := cmd.Process.Pid + s.vars["!"] = fmt.Sprintf("%d", pid) + fmt.Fprintf(s.Stderr, "[1] %d\n", pid) + go func() { cmd.Wait() }() + return nil +} + +func findExecutable(name string) string { + // Direct path + if strings.ContainsAny(name, "/\\") { + if info, err := os.Stat(name); err == nil && !info.IsDir() { + abs, _ := filepath.Abs(name) + return abs + } + return "" + } + path := os.Getenv("PATH") + for _, dir := range filepath.SplitList(path) { + for _, candidate := range []string{ + filepath.Join(dir, name), + filepath.Join(dir, name+".exe"), + filepath.Join(dir, name+".cmd"), + filepath.Join(dir, name+".bat"), + } { + if info, err := os.Stat(candidate); err == nil && !info.IsDir() { + return candidate + } + } + } + return "" +} + +// doPipe executes a pipeline where each stage feeds into the next. +func (s *Shell) doPipe(commands []string, bg bool) error { + _ = bg // background pipe support would require goroutines; skip for now + + var prevBuf []byte + + for i, cmd := range commands { + isLast := i == len(commands)-1 + var stdinReader io.Reader + + if i == 0 { + stdinReader = s.Stdin + } else { + stdinReader = bytes.NewReader(prevBuf) + } + + if isLast { + return s.withIO(stdinReader, nil, nil, func() error { + return s.executeCommand(cmd) + }) + } + + var buf bytes.Buffer + s.withIO(stdinReader, &buf, nil, func() error { + return s.executeCommand(cmd) + }) + prevBuf = buf.Bytes() + } + return nil +} diff --git a/internal/shell/expand.go b/internal/shell/expand.go new file mode 100644 index 0000000..9ba70b7 --- /dev/null +++ b/internal/shell/expand.go @@ -0,0 +1,448 @@ +package shell + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +func isVarChar(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' +} + +// expandWord expands $VAR, ${VAR}, $(...), $((...)) in a single token. +// Quote characters (single/double) are interpreted here and stripped from output. +func (s *Shell) expandWord(word string) string { + // Tilde expansion (only when not quoted) + if strings.HasPrefix(word, "~") { + home := s.GetVar("HOME") + if home == "" { + home = os.Getenv("USERPROFILE") + } + if len(word) == 1 { + return home + } + if word[1] == '/' || word[1] == '\\' { + return home + word[1:] + } + } + + var result strings.Builder + inSingle := false + inDouble := false + i := 0 + + for i < len(word) { + ch := word[i] + + switch { + case ch == '\'' && !inDouble: + inSingle = !inSingle + i++ + case ch == '"' && !inSingle: + inDouble = !inDouble + i++ + case ch == '\\' && !inSingle: + if i+1 < len(word) { + next := word[i+1] + if inDouble { + // In double quotes, only certain chars are escaped + switch next { + case '$', '`', '"', '\\', '\n': + result.WriteByte(next) + i += 2 + default: + result.WriteByte('\\') + result.WriteByte(next) + i += 2 + } + } else { + result.WriteByte(next) + i += 2 + } + } else { + i++ + } + case ch == '$' && !inSingle: + i++ // skip $ + if i >= len(word) { + result.WriteByte('$') + break + } + switch word[i] { + case '(': + if i+1 < len(word) && word[i+1] == '(' { + // $(( arithmetic )) + j := i + 2 + depth := 2 + for j < len(word) { + if word[j] == '(' { + depth++ + } + if word[j] == ')' { + depth-- + if depth == 0 { + j++ + break + } + } + j++ + } + expr := word[i+2 : j-2] + result.WriteString(strconv.Itoa(s.evalArith(expr))) + i = j + } else { + // $( command substitution ) + j := i + 1 + depth := 1 + for j < len(word) { + if word[j] == '(' { + depth++ + } + if word[j] == ')' { + depth-- + if depth == 0 { + break + } + } + j++ + } + cmd := word[i+1 : j] + out := s.captureCommand(cmd) + result.WriteString(strings.TrimRight(out, "\n")) + i = j + 1 + } + case '{': + j := i + 1 + depth := 1 + for j < len(word) { + if word[j] == '{' { + depth++ + } + if word[j] == '}' { + depth-- + if depth == 0 { + break + } + } + j++ + } + varExpr := word[i+1 : j] + result.WriteString(s.evalVarExpr(varExpr)) + i = j + 1 + case '?': + result.WriteString(s.vars["?"]) + i++ + case '$': + result.WriteString(fmt.Sprintf("%d", os.Getpid())) + i++ + case '!': + result.WriteString(s.vars["!"]) + i++ + case '#': + result.WriteString(s.vars["#"]) + i++ + case '@': + result.WriteString(s.vars["@"]) + i++ + case '*': + result.WriteString(s.vars["*"]) + i++ + default: + j := i + for j < len(word) && isVarChar(word[j]) { + j++ + } + if j == i { + result.WriteByte('$') + } else { + result.WriteString(s.getVar(word[i:j])) + i = j + } + } + default: + result.WriteByte(ch) + i++ + } + } + return result.String() +} + +func (s *Shell) getVar(name string) string { + if v, ok := s.vars[name]; ok { + return v + } + return os.Getenv(name) +} + +func (s *Shell) evalVarExpr(expr string) string { + // ${#VAR} — string length + if strings.HasPrefix(expr, "#") { + return strconv.Itoa(len(s.getVar(expr[1:]))) + } + // ${VAR:-default} + if idx := strings.Index(expr, ":-"); idx >= 0 { + varName := expr[:idx] + if v := s.getVar(varName); v != "" { + return v + } + return s.expandWord(expr[idx+2:]) + } + // ${VAR:=default} + if idx := strings.Index(expr, ":="); idx >= 0 { + varName := expr[:idx] + if v := s.getVar(varName); v != "" { + return v + } + expanded := s.expandWord(expr[idx+2:]) + s.vars[varName] = expanded + return expanded + } + // ${VAR:+alt} + if idx := strings.Index(expr, ":+"); idx >= 0 { + varName := expr[:idx] + if v := s.getVar(varName); v != "" { + return s.expandWord(expr[idx+2:]) + } + return "" + } + // ${VAR%pattern} — strip shortest suffix + if idx := strings.Index(expr, "%"); idx >= 0 { + varName := expr[:idx] + pattern := expr[idx+1:] + v := s.getVar(varName) + if strings.HasSuffix(v, pattern) { + return v[:len(v)-len(pattern)] + } + return v + } + // ${VAR#pattern} — strip shortest prefix + if idx := strings.Index(expr, "#"); idx >= 0 { + varName := expr[:idx] + pattern := expr[idx+1:] + v := s.getVar(varName) + if strings.HasPrefix(v, pattern) { + return v[len(pattern):] + } + return v + } + return s.getVar(expr) +} + +// captureCommand runs a command and returns its stdout as a string. +func (s *Shell) captureCommand(cmd string) string { + var buf bytes.Buffer + s.withIO(nil, &buf, nil, func() error { + return s.Execute(cmd) + }) + return buf.String() +} + +// evalArith evaluates a shell arithmetic expression. +func (s *Shell) evalArith(expr string) int { + expr = strings.TrimSpace(s.expandWord(expr)) + // Expand bare variable names (e.g. i+1 → value_of_i + 1) + expr = s.expandArithVars(expr) + return evalArithExpr(expr) +} + +// expandArithVars replaces bare identifier names with their shell variable values. +func (s *Shell) expandArithVars(expr string) string { + var result strings.Builder + i := 0 + for i < len(expr) { + c := expr[i] + // Identifier start (letter or _), but not a digit + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' { + j := i + for j < len(expr) && isVarChar(expr[j]) { + j++ + } + varName := expr[i:j] + val := s.getVar(varName) + if val == "" { + val = "0" + } + result.WriteString(val) + i = j + } else { + result.WriteByte(c) + i++ + } + } + return result.String() +} + +func evalArithExpr(expr string) int { + expr = strings.TrimSpace(expr) + if n, err := strconv.Atoi(expr); err == nil { + return n + } + // Strip outer parens + if strings.HasPrefix(expr, "(") && strings.HasSuffix(expr, ")") { + return evalArithExpr(expr[1 : len(expr)-1]) + } + // Operators in precedence order (lowest first so we split on last occurrence) + for _, op := range []string{"+", "-", "*", "/", "%"} { + if idx := findBinaryOp(expr, op); idx >= 0 { + left := evalArithExpr(expr[:idx]) + right := evalArithExpr(expr[idx+1:]) + switch op { + case "+": + return left + right + case "-": + return left - right + case "*": + return left * right + case "/": + if right == 0 { + return 0 + } + return left / right + case "%": + if right == 0 { + return 0 + } + return left % right + } + } + } + return 0 +} + +func findBinaryOp(expr, op string) int { + depth := 0 + for i := len(expr) - 1; i >= 0; i-- { + switch expr[i] { + case ')': + depth++ + case '(': + depth-- + } + if depth != 0 { + continue + } + if expr[i:i+1] == op { + if (op == "-" || op == "+") && i == 0 { + continue + } + return i + } + } + return -1 +} + +// expandGlob expands glob patterns; returns original if no match. +func (s *Shell) expandGlob(word string) []string { + if !strings.ContainsAny(word, "*?[") { + return []string{word} + } + matches, err := filepath.Glob(word) + if err != nil || len(matches) == 0 { + return []string{word} + } + return matches +} + +// tokenize splits input into tokens, expands variables, handles quotes and globs. +func (s *Shell) tokenize(input string) []string { + var rawTokens []string + current := strings.Builder{} + inSingle := false + inDouble := false + parenDepth := 0 // nesting depth inside $(...) or $((...)) + pendingDollar := false // true after $ when next char is ( + wasQuoted := false + + flush := func() { + if current.Len() > 0 { + tok := current.String() + if wasQuoted { + tok = "\x00q" + tok + } + rawTokens = append(rawTokens, tok) + current.Reset() + wasQuoted = false + pendingDollar = false + } + } + + for i := 0; i < len(input); i++ { + c := input[i] + switch { + case c == '\'' && !inDouble && parenDepth == 0: + inSingle = !inSingle + wasQuoted = true + current.WriteByte(c) + case c == '"' && !inSingle && parenDepth == 0: + inDouble = !inDouble + wasQuoted = true + current.WriteByte(c) + case c == '$' && !inSingle && i+1 < len(input) && (input[i+1] == '(' || input[i+1] == '{'): + // Mark that the next ( opens a substitution — don't increment depth here + if input[i+1] == '(' { + pendingDollar = true + } + current.WriteByte(c) + case c == '(' && !inSingle && !inDouble && (parenDepth > 0 || pendingDollar): + parenDepth++ + pendingDollar = false + current.WriteByte(c) + case c == ')' && !inSingle && !inDouble && parenDepth > 0: + parenDepth-- + current.WriteByte(c) + case c == '{' && !inSingle && !inDouble && parenDepth > 0: + parenDepth++ + current.WriteByte(c) + case c == '}' && !inSingle && !inDouble && parenDepth > 0: + parenDepth-- + current.WriteByte(c) + case (c == ' ' || c == '\t') && !inSingle && !inDouble && parenDepth == 0: + flush() + case c == '#' && !inSingle && !inDouble && parenDepth == 0 && current.Len() == 0: + // Inline comment: # at start of a new token — discard the rest of the input + goto doneTokenizing + default: + pendingDollar = false + current.WriteByte(c) + } + } +doneTokenizing: + flush() + + // Handle variable assignment on token[0]: FOO=bar + if len(rawTokens) > 0 { + tok := rawTokens[0] + clean := strings.TrimPrefix(tok, "\x00q") + if eqIdx := strings.Index(clean, "="); eqIdx > 0 { + name := clean[:eqIdx] + if isValidIdentifier(name) && !strings.Contains(clean[:eqIdx], "$") { + value := s.expandWord(clean[eqIdx+1:]) + s.vars[name] = value + os.Setenv(name, value) + rawTokens = rawTokens[1:] + if len(rawTokens) == 0 { + return nil + } + } + } + } + + var result []string + for _, tok := range rawTokens { + quoted := strings.HasPrefix(tok, "\x00q") + if quoted { + tok = tok[2:] + } + expanded := s.expandWord(tok) + if !quoted && strings.ContainsAny(expanded, "*?[") { + result = append(result, s.expandGlob(expanded)...) + } else { + result = append(result, expanded) + } + } + return result +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0e5cfb1..0055260 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -1,448 +1,509 @@ package shell import ( - "bytes" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" + "fmt" + "io" + "os" + "strings" ) +// Sentinel errors for control flow +type breakErr struct{ n int } +type continueErr struct{ n int } +type returnErr struct{ code int } + +func (e breakErr) Error() string { return fmt.Sprintf("break %d", e.n) } +func (e continueErr) Error() string { return fmt.Sprintf("continue %d", e.n) } +func (e returnErr) Error() string { return fmt.Sprintf("return %d", e.code) } + +// exitCodeErr carries a non-zero exit code that sets $? without a message. +// Used by functions, test, false, etc. +type exitCodeErr struct{ code int } + +func (e exitCodeErr) Error() string { return "" } + type Shell struct { - vars map[string]string - builtins map[string]func([]string) error - lastExit int + vars map[string]string + builtins map[string]func([]string) error + funcs map[string]string // function name → body + lastExit int + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer + args []string } func New() *Shell { - s := &Shell{ - vars: map[string]string{}, - lastExit: 0, - } - s.initBuiltins() - s.vars["SHELL"] = "bash-for-windows" - s.vars["BASH_VERSION"] = "1.0.0" - s.vars["?"] = "0" - if pwd, err := os.Getwd(); err == nil { - s.vars["PWD"] = pwd - } - if home, err := os.UserHomeDir(); err == nil { - s.vars["HOME"] = home - } - return s + s := &Shell{ + vars: map[string]string{}, + funcs: map[string]string{}, + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + } + s.initBuiltins() + s.vars["SHELL"] = "bash-for-windows" + s.vars["BASH_VERSION"] = "5.2.15(1)-release" + s.vars["?"] = "0" + s.vars["#"] = "0" + s.vars["@"] = "" + s.vars["*"] = "" + s.vars["!"] = "" + if pwd, err := os.Getwd(); err == nil { + s.vars["PWD"] = pwd + } + if home, err := os.UserHomeDir(); err == nil { + s.vars["HOME"] = home + } + return s } -func (s *Shell) Execute(input string) error { - input = strings.TrimSpace(input) - if input == "" { - return nil - } +func (s *Shell) SetArgs(args []string) { + s.args = args + s.vars["#"] = fmt.Sprintf("%d", len(args)) + s.vars["@"] = strings.Join(args, " ") + s.vars["*"] = strings.Join(args, " ") + for i, a := range args { + s.vars[fmt.Sprintf("%d", i+1)] = a + } +} - lines := strings.Split(input, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - if err := s.executeLine(line); err != nil { - return err - } - } - return nil +func (s *Shell) GetVar(name string) string { + if v, ok := s.vars[name]; ok { + return v + } + return os.Getenv(name) +} + +func (s *Shell) SetVar(name, value string) { + s.vars[name] = value +} + +// Execute runs commands from the given input string. +func (s *Shell) Execute(input string) error { + input = strings.ReplaceAll(input, "\\\n", " ") + blocks := parseBlocks(input) + for _, block := range blocks { + if err := s.executeBlock(block); err != nil { + switch err.(type) { + case breakErr, continueErr, returnErr: + return err + } + s.setExitCode(err) + } + } + return nil +} + +// IsIncomplete returns true if the input is an incomplete multi-line construct. +func IsIncomplete(input string) bool { + stmts := splitStatements(input) + depth := 0 + inSingle := false + inDouble := false + for _, ch := range input { + switch ch { + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + } + } + if inSingle || inDouble { + return true + } + for _, stmt := range stmts { + w := firstWord(stmt) + switch w { + case "if", "for", "while", "until": + depth++ + case "fi", "done", "esac": + depth-- + case "{": + depth++ + case "}": + if depth > 0 { + depth-- + } + } + if isFuncDefStart(stmt) && strings.HasSuffix(strings.TrimSpace(stmt), "{") { + depth++ + } + } + return depth > 0 +} + +// parseBlocks groups statements into logical execution units. +// Multi-line if/for/while/function blocks are gathered into single entries. +func parseBlocks(input string) []string { + stmts := splitStatements(input) + var blocks []string + var current []string + kwDepth := 0 // if/for/while/until → fi/done nesting + inFunc := false + funcKwDepth := 0 // keyword nesting inside a function body + + for _, stmt := range stmts { + stmt = strings.TrimSpace(stmt) + if stmt == "" || strings.HasPrefix(stmt, "#") { + continue + } + w := firstWord(stmt) + + if !inFunc { + // Detect function definition opening with `{` + if isFuncDefStart(stmt) && strings.Contains(stmt, "{") { + braceIdx := strings.Index(stmt, "{") + // Count keywords after the { on this same line + funcKwDepth = 0 + for _, p := range splitStatements(stmt[braceIdx+1:]) { + switch firstWord(p) { + case "if", "for", "while", "until": + funcKwDepth++ + case "fi", "done", "esac": + funcKwDepth-- + } + } + // If the line also ends with } it's a self-contained function + if strings.HasSuffix(stmt, "}") { + current = append(current, stmt) + blocks = append(blocks, strings.Join(current, "\n")) + current = nil + funcKwDepth = 0 + continue + } + inFunc = true + current = append(current, stmt) + continue + } + + switch w { + case "if", "for", "while", "until": + kwDepth++ + } + kwDepth += embeddedKwDepth(stmt) + current = append(current, stmt) + switch w { + case "fi", "done", "esac": + kwDepth-- + case "}": + if kwDepth > 0 { + kwDepth-- + } + } + if kwDepth <= 0 && len(current) > 0 { + kwDepth = 0 + blocks = append(blocks, strings.Join(current, "\n")) + current = nil + } + } else { + // Inside function body — watch for } at funcKwDepth==0 + if w == "}" && funcKwDepth <= 0 { + current = append(current, stmt) + blocks = append(blocks, strings.Join(current, "\n")) + current = nil + inFunc = false + funcKwDepth = 0 + continue + } + switch w { + case "if", "for", "while", "until": + funcKwDepth++ + case "fi", "done", "esac": + funcKwDepth-- + } + funcKwDepth += embeddedKwDepth(stmt) + current = append(current, stmt) + } + } + if len(current) > 0 { + blocks = append(blocks, strings.Join(current, "\n")) + } + return blocks +} + +// embeddedKwDepth returns the net depth change from keywords that appear +// after do/then/else/elif within a single statement (excluding the first word, +// which is handled separately by the caller). +func embeddedKwDepth(stmt string) int { + words := strings.Fields(stmt) + delta := 0 + for j := 1; j < len(words); j++ { + switch words[j-1] { + case "do", "then", "else", "elif": + switch words[j] { + case "if", "for", "while", "until": + delta++ + case "fi", "done", "esac": + delta-- + } + } + } + return delta +} + +// splitStatements splits input on semicolons and newlines, respecting quotes. +func splitStatements(input string) []string { + var result []string + current := strings.Builder{} + inSingle := false + inDouble := false + + for i := 0; i < len(input); i++ { + c := input[i] + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + current.WriteByte(c) + case c == '"' && !inSingle: + inDouble = !inDouble + current.WriteByte(c) + case (c == ';' || c == '\n') && !inSingle && !inDouble: + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + current.Reset() + default: + current.WriteByte(c) + } + } + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + return result +} + +func firstWord(s string) string { + fields := strings.Fields(s) + if len(fields) == 0 { + return "" + } + return fields[0] +} + +func afterWord(s string) string { + for i, ch := range s { + if ch == ' ' || ch == '\t' { + return strings.TrimSpace(s[i:]) + } + } + return "" +} + +func isFuncDefStart(stmt string) bool { + if strings.HasPrefix(stmt, "function ") { + return true + } + for i, ch := range stmt { + if ch == ' ' || ch == '\t' { + break + } + if ch == '(' { + name := strings.TrimSpace(stmt[:i]) + return isValidIdentifier(name) + } + } + return false +} + +func isValidIdentifier(s string) bool { + if len(s) == 0 { + return false + } + for i, c := range s { + if i == 0 { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return false + } + } else { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return false + } + } + } + return true +} + +func (s *Shell) executeBlock(block string) error { + block = strings.TrimSpace(block) + if block == "" || strings.HasPrefix(block, "#") { + return nil + } + w := firstWord(block) + switch w { + case "if": + return s.executeIf(block) + case "for": + return s.executeFor(block) + case "while": + return s.executeWhileUntil(block, false) + case "until": + return s.executeWhileUntil(block, true) + } + if isFuncDefStart(block) { + return s.defineFunction(block) + } + for _, line := range strings.Split(block, "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if err := s.executeLine(line); err != nil { + return err + } + } + return nil } func (s *Shell) executeLine(line string) error { - // Tokenize the line into its components, handling &&, ||, ; - return s.executeChain(line) + return s.executeChain(line) } -// executeChain: parse && / || / ; with left-to-right precedence func (s *Shell) executeChain(line string) error { - // Strategy: split by ; first (semicolons always separate), - // then by &&/|| within each segment - segments := splitBySemicolon(line) - - for _, seg := range segments { - seg = strings.TrimSpace(seg) - if seg == "" { - continue - } - if err := s.executeAndOrList(seg); err != nil { - // In ; chains, errors in one command don't stop execution - s.setExitCode(err) - } - } - return nil + for _, seg := range splitBySemicolon(line) { + seg = strings.TrimSpace(seg) + if seg == "" { + continue + } + if err := s.executeAndOrList(seg); err != nil { + switch err.(type) { + case breakErr, continueErr, returnErr: + return err + } + s.setExitCode(err) + } + } + return nil } func splitBySemicolon(line string) []string { - var parts []string - current := "" - inSingle := false - inDouble := false - - for i := 0; i < len(line); i++ { - c := line[i] - switch { - case c == '\'' && !inDouble: - inSingle = !inSingle - current += string(c) - case c == '"' && !inSingle: - inDouble = !inDouble - current += string(c) - case c == ';' && !inSingle && !inDouble: - parts = append(parts, current) - current = "" - default: - current += string(c) - } - } - if current != "" { - parts = append(parts, current) - } - return parts + var parts []string + current := strings.Builder{} + inSingle := false + inDouble := false + + for i := 0; i < len(line); i++ { + c := line[i] + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + current.WriteByte(c) + case c == '"' && !inSingle: + inDouble = !inDouble + current.WriteByte(c) + case c == ';' && !inSingle && !inDouble: + parts = append(parts, current.String()) + current.Reset() + default: + current.WriteByte(c) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts } -// executeAndOrList: parse && / || with left-to-right precedence func (s *Shell) executeAndOrList(line string) error { - type token struct { - text string - op string // operator BEFORE this token (except first = "") - } + type tok struct { + text string + op string + } + var tokens []tok + current := strings.Builder{} + op := "" + inSingle := false + inDouble := false - var tokens []token - current := "" - op := "" - inSingle := false - inDouble := false + for i := 0; i < len(line); i++ { + c := line[i] + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + current.WriteByte(c) + case c == '"' && !inSingle: + inDouble = !inDouble + current.WriteByte(c) + case c == '&' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '&': + tokens = append(tokens, tok{current.String(), op}) + current.Reset() + op = "&&" + i++ + case c == '|' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '|': + tokens = append(tokens, tok{current.String(), op}) + current.Reset() + op = "||" + i++ + default: + current.WriteByte(c) + } + } + if current.Len() > 0 { + tokens = append(tokens, tok{current.String(), op}) + } - for i := 0; i < len(line); i++ { - c := line[i] - switch { - case c == '\'' && !inDouble: - inSingle = !inSingle - current += string(c) - case c == '"' && !inSingle: - inDouble = !inDouble - current += string(c) - case c == '&' && !inSingle && !inDouble: - if i+1 < len(line) && line[i+1] == '&' { - if current != "" { - tokens = append(tokens, token{current, op}) - current = "" - } - op = "&&" - i++ - } else { - current += string(c) - } - case c == '|' && !inSingle && !inDouble: - if i+1 < len(line) && line[i+1] == '|' { - if current != "" { - tokens = append(tokens, token{current, op}) - current = "" - } - op = "||" - i++ - } else { - current += string(c) - } - default: - current += string(c) - } - } - if current != "" { - tokens = append(tokens, token{current, op}) - } - - if len(tokens) == 0 { - return nil - } - - var lastErr error - for i, tok := range tokens { - cmd := strings.TrimSpace(tok.text) - if cmd == "" { - continue - } - - shouldRun := true - if i > 0 && tok.op == "&&" { - shouldRun = (lastErr == nil) - } else if i > 0 && tok.op == "||" { - shouldRun = (lastErr != nil) - } - - if shouldRun { - err := s.executePipeline(cmd) - lastErr = err - s.setExitCode(err) - } - } - - return lastErr + var lastErr error + for i, t := range tokens { + cmd := strings.TrimSpace(t.text) + if cmd == "" { + continue + } + run := i == 0 + if i > 0 { + run = (t.op == "&&" && lastErr == nil) || (t.op == "||" && lastErr != nil) + } + if run { + err := s.executePipeline(cmd) + lastErr = err + s.setExitCode(err) + } + } + return lastErr } func (s *Shell) setExitCode(err error) { - if err != nil { - s.vars["?"] = "1" - } else { - s.vars["?"] = "0" - } + if err == nil { + s.vars["?"] = "0" + s.lastExit = 0 + } else if ec, ok := err.(exitCodeErr); ok { + s.vars["?"] = fmt.Sprintf("%d", ec.code) + s.lastExit = ec.code + } else { + s.vars["?"] = "1" + s.lastExit = 1 + } } -func (s *Shell) executePipeline(input string) error { - input = strings.TrimSpace(input) - if input == "" { - return nil - } - - if strings.Contains(input, "|") { - return s.doPipe(input) - } - - return s.executeCommand(input) +// BuiltinNames returns a sorted list of all registered builtin names (for tab completion). +func (s *Shell) BuiltinNames() []string { + names := make([]string, 0, len(s.builtins)+len(s.funcs)) + for k := range s.builtins { + names = append(names, k) + } + for k := range s.funcs { + names = append(names, k) + } + return names } -func (s *Shell) executeCommand(input string) error { - parts := s.tokenize(input) - if len(parts) == 0 { - return nil - } - - cmdName := parts[0] - args := parts[1:] - - if alias, ok := aliases[cmdName]; ok { - fullCmd := alias - if len(args) > 0 { - fullCmd += " " + strings.Join(args, " ") - } - return s.Execute(fullCmd) - } - - if builtin, ok := s.builtins[cmdName]; ok { - return builtin(args) - } - - return s.executeExternal(cmdName, args) -} - -func (s *Shell) tokenize(input string) []string { - var tokens []string - current := "" - inSingle := false - inDouble := false - - for i := 0; i < len(input); i++ { - c := input[i] - - switch { - case c == '\'' && !inDouble: - inSingle = !inSingle - case c == '"' && !inSingle: - inDouble = !inDouble - case (c == ' ' || c == '\t') && !inSingle && !inDouble: - if current != "" { - tokens = append(tokens, current) - current = "" - } - case c == '=' && !inSingle && !inDouble: - current += string(c) - default: - current += string(c) - } - } - - if current != "" { - tokens = append(tokens, current) - } - - for i, tok := range tokens { - if strings.HasPrefix(tok, "$") { - key := tok[1:] - if strings.HasPrefix(key, "{") && strings.HasSuffix(key, "}") { - key = key[1 : len(key)-1] - } - if val, ok := s.vars[key]; ok { - tokens[i] = val - } else if val := os.Getenv(key); val != "" { - tokens[i] = val - } - } - } - - varAssignIdx := -1 - for i, tok := range tokens { - if strings.Contains(tok, "=") && i == 0 { - eqIdx := strings.Index(tok, "=") - if eqIdx > 0 && eqIdx < len(tok)-1 { - name := tok[:eqIdx] - value := tok[eqIdx+1:] - s.vars[name] = value - os.Setenv(name, value) - varAssignIdx = i - } - } - } - if varAssignIdx == 0 && len(tokens) > 1 { - tokens = tokens[1:] - } - - return tokens -} - -func (s *Shell) executeExternal(cmdName string, args []string) error { - cmdPath := findExecutable(cmdName) - if cmdPath == "" { - return fmt.Errorf("%s: command not found", cmdName) - } - - cmd := exec.Command(cmdPath, args...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - return cmd.Run() -} - -func findExecutable(name string) string { - if _, err := os.Stat(name); err == nil { - if info, _ := os.Stat(name); info != nil && !info.IsDir() { - abs, _ := filepath.Abs(name) - return abs - } - } - - path := os.Getenv("PATH") - for _, dir := range filepath.SplitList(path) { - fullPath := filepath.Join(dir, name) - info, err := os.Stat(fullPath) - if err == nil && !info.IsDir() { - return fullPath - } - fullPathExe := fullPath + ".exe" - info, err = os.Stat(fullPathExe) - if err == nil && !info.IsDir() { - return fullPathExe - } - } - return "" -} - -func (s *Shell) doPipe(input string) error { - commands := strings.Split(input, "|") - - type cmdPart struct { - name string - args []string - builtin bool - } - - var parts []cmdPart - for _, part := range commands { - part = strings.TrimSpace(part) - if part == "" { - continue - } - tokens := s.tokenize(part) - if len(tokens) == 0 { - continue - } - name := tokens[0] - _, isBuiltin := s.builtins[name] - parts = append(parts, cmdPart{name, tokens[1:], isBuiltin}) - } - - if len(parts) == 0 { - return nil - } - - var prevOutput []byte - - for i, p := range parts { - var input []byte - if i > 0 { - input = prevOutput - } - - if p.builtin { - output, err := s.captureBuiltin(p.name, p.args, input) - if err != nil { - // Don't return error, let it pass - prevOutput = nil - if i == len(parts)-1 { - return err - } - continue - } - if i == len(parts)-1 { - fmt.Print(string(output)) - } else { - prevOutput = output - } - } else { - cmdPath := findExecutable(p.name) - if cmdPath == "" { - return fmt.Errorf("%s: command not found", p.name) - } - cmd := exec.Command(cmdPath, p.args...) - cmd.Stderr = os.Stderr - - if i == 0 && len(input) == 0 { - cmd.Stdin = os.Stdin - } else if len(input) > 0 { - stdin, _ := cmd.StdinPipe() - go func() { - stdin.Write(input) - stdin.Close() - }() - } - - if i == len(parts)-1 { - cmd.Stdout = os.Stdout - if err := cmd.Run(); err != nil { - return err - } - } else { - output, err := cmd.Output() - if err != nil { - return err - } - prevOutput = output - } - } - } - - return nil -} - -func (s *Shell) captureBuiltin(name string, args []string, input []byte) ([]byte, error) { - oldStdout := os.Stdout - oldStdin := os.Stdin - r, w, _ := os.Pipe() - os.Stdout = w - - if len(input) > 0 { - ir, iw, _ := os.Pipe() - iw.Write(input) - iw.Close() - os.Stdin = ir - defer ir.Close() - } - - fn := s.builtins[name] - err := fn(args) - - w.Close() - os.Stdout = oldStdout - os.Stdin = oldStdin - - var buf bytes.Buffer - io.Copy(&buf, r) - r.Close() - - return buf.Bytes(), err +// withIO temporarily swaps stdin/stdout/stderr, runs fn, then restores. +// Pass nil to leave the corresponding stream unchanged. +func (s *Shell) withIO(stdin io.Reader, stdout io.Writer, stderr io.Writer, fn func() error) error { + oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr + if stdin != nil { + s.Stdin = stdin + } + if stdout != nil { + s.Stdout = stdout + } + if stderr != nil { + s.Stderr = stderr + } + err := fn() + s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr + return err }